feat: 添加工具调用支持,优化消息处理和持久化
This commit is contained in:
parent
8bb32fa066
commit
ef601107ac
@ -211,6 +211,7 @@ fn chat_message_to_llm_message(m: &ChatMessage) -> Message {
|
|||||||
content,
|
content,
|
||||||
tool_call_id: m.tool_call_id.clone(),
|
tool_call_id: m.tool_call_id.clone(),
|
||||||
name: m.tool_name.clone(),
|
name: m.tool_name.clone(),
|
||||||
|
tool_calls: m.tool_calls.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -223,6 +224,12 @@ pub struct AgentLoop {
|
|||||||
max_iterations: usize,
|
max_iterations: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct AgentProcessResult {
|
||||||
|
pub final_response: ChatMessage,
|
||||||
|
pub emitted_messages: Vec<ChatMessage>,
|
||||||
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
pub fn new(provider_config: LLMProviderConfig) -> Result<Self, AgentError> {
|
||||||
let max_iterations = provider_config.max_tool_iterations;
|
let max_iterations = provider_config.max_tool_iterations;
|
||||||
@ -267,12 +274,13 @@ impl AgentLoop {
|
|||||||
/// it loops back to the LLM with the tool results until either:
|
/// it loops back to the LLM with the tool results until either:
|
||||||
/// - The LLM returns no more tool calls (final response)
|
/// - The LLM returns no more tool calls (final response)
|
||||||
/// - Maximum iterations are reached
|
/// - Maximum iterations are reached
|
||||||
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<ChatMessage, AgentError> {
|
pub async fn process(&self, mut messages: Vec<ChatMessage>) -> Result<AgentProcessResult, AgentError> {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
tracing::debug!(history_len = messages.len(), max_iterations = self.max_iterations, "Starting agent process");
|
||||||
|
|
||||||
// Track tool calls for loop detection
|
// Track tool calls for loop detection
|
||||||
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
let mut loop_detector = LoopDetector::new(LoopDetectorConfig::default());
|
||||||
|
let mut emitted_messages = Vec::new();
|
||||||
|
|
||||||
for iteration in 0..self.max_iterations {
|
for iteration in 0..self.max_iterations {
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -316,15 +324,23 @@ impl AgentLoop {
|
|||||||
// If no tool calls, this is the final response
|
// If no tool calls, this is the final response
|
||||||
if response.tool_calls.is_empty() {
|
if response.tool_calls.is_empty() {
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
let assistant_message = ChatMessage::assistant(response.content);
|
||||||
return Ok(assistant_message);
|
emitted_messages.push(assistant_message.clone());
|
||||||
|
return Ok(AgentProcessResult {
|
||||||
|
final_response: assistant_message,
|
||||||
|
emitted_messages,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute tool calls
|
// Execute tool calls
|
||||||
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
tracing::info!(iteration, count = response.tool_calls.len(), "Tool calls detected, executing tools");
|
||||||
|
|
||||||
// Add assistant message with tool calls
|
// Add assistant message with tool calls
|
||||||
let assistant_message = ChatMessage::assistant(response.content.clone());
|
let assistant_message = ChatMessage::assistant_with_tool_calls(
|
||||||
|
response.content.clone(),
|
||||||
|
response.tool_calls.clone(),
|
||||||
|
);
|
||||||
messages.push(assistant_message.clone());
|
messages.push(assistant_message.clone());
|
||||||
|
emitted_messages.push(assistant_message);
|
||||||
|
|
||||||
// Execute tools and add results to messages
|
// Execute tools and add results to messages
|
||||||
let tool_results = self.execute_tools(&response.tool_calls).await;
|
let tool_results = self.execute_tools(&response.tool_calls).await;
|
||||||
@ -356,7 +372,8 @@ impl AgentLoop {
|
|||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
format!("{}\n\n[上一条结果]\n{}", msg, truncated_output),
|
||||||
);
|
);
|
||||||
messages.push(tool_message);
|
messages.push(tool_message.clone());
|
||||||
|
emitted_messages.push(tool_message);
|
||||||
}
|
}
|
||||||
LoopDetectionResult::Ok => {
|
LoopDetectionResult::Ok => {
|
||||||
let tool_message = ChatMessage::tool(
|
let tool_message = ChatMessage::tool(
|
||||||
@ -364,7 +381,8 @@ impl AgentLoop {
|
|||||||
tool_call.name.clone(),
|
tool_call.name.clone(),
|
||||||
truncated_output,
|
truncated_output,
|
||||||
);
|
);
|
||||||
messages.push(tool_message);
|
messages.push(tool_message.clone());
|
||||||
|
emitted_messages.push(tool_message);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -400,7 +418,11 @@ impl AgentLoop {
|
|||||||
match (*self.provider).chat(request).await {
|
match (*self.provider).chat(request).await {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
let assistant_message = ChatMessage::assistant(response.content);
|
let assistant_message = ChatMessage::assistant(response.content);
|
||||||
Ok(assistant_message)
|
emitted_messages.push(assistant_message.clone());
|
||||||
|
Ok(AgentProcessResult {
|
||||||
|
final_response: assistant_message,
|
||||||
|
emitted_messages,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Fallback if summary call fails
|
// Fallback if summary call fails
|
||||||
@ -408,7 +430,11 @@ impl AgentLoop {
|
|||||||
let final_message = ChatMessage::assistant(
|
let final_message = ChatMessage::assistant(
|
||||||
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
format!("I reached the maximum number of tool call iterations ({}) without completing the task. The work done so far has been lost due to an error. Please try breaking the task into smaller steps.", self.max_iterations)
|
||||||
);
|
);
|
||||||
Ok(final_message)
|
emitted_messages.push(final_message.clone());
|
||||||
|
Ok(AgentProcessResult {
|
||||||
|
final_response: final_message,
|
||||||
|
emitted_messages,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -593,6 +619,25 @@ mod tests {
|
|||||||
// If there's only 1 tool, should return false regardless
|
// If there's only 1 tool, should return false regardless
|
||||||
assert_eq!(calls.len() <= 1, true);
|
assert_eq!(calls.len() <= 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_chat_message_to_llm_message_preserves_assistant_tool_calls() {
|
||||||
|
let chat_message = ChatMessage::assistant_with_tool_calls(
|
||||||
|
"calling tool",
|
||||||
|
vec![ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: serde_json::json!({ "expression": "2+2" }),
|
||||||
|
}],
|
||||||
|
);
|
||||||
|
|
||||||
|
let provider_message = chat_message_to_llm_message(&chat_message);
|
||||||
|
|
||||||
|
assert_eq!(provider_message.role, "assistant");
|
||||||
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap().len(), 1);
|
||||||
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||||
|
assert_eq!(provider_message.tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
pub mod agent_loop;
|
pub mod agent_loop;
|
||||||
pub mod context_compressor;
|
pub mod context_compressor;
|
||||||
|
|
||||||
pub use agent_loop::{AgentLoop, AgentError};
|
pub use agent_loop::{AgentLoop, AgentError, AgentProcessResult};
|
||||||
pub use context_compressor::ContextCompressor;
|
pub use context_compressor::ContextCompressor;
|
||||||
|
|||||||
@ -1,6 +1,8 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
// ContentBlock - Multimodal content representation (OpenAI-style)
|
// ContentBlock - Multimodal content representation (OpenAI-style)
|
||||||
// ============================================================================
|
// ============================================================================
|
||||||
@ -69,6 +71,8 @@ pub struct ChatMessage {
|
|||||||
pub tool_call_id: Option<String>,
|
pub tool_call_id: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_name: Option<String>,
|
pub tool_name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatMessage {
|
impl ChatMessage {
|
||||||
@ -81,6 +85,7 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,6 +98,7 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,6 +111,20 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assistant_with_tool_calls(content: impl Into<String>, tool_calls: Vec<ToolCall>) -> Self {
|
||||||
|
Self {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: content.into(),
|
||||||
|
media_refs: Vec::new(),
|
||||||
|
timestamp: current_timestamp(),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_name: None,
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -117,6 +137,7 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
tool_name: None,
|
tool_name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -129,6 +150,7 @@ impl ChatMessage {
|
|||||||
timestamp: current_timestamp(),
|
timestamp: current_timestamp(),
|
||||||
tool_call_id: Some(tool_call_id.into()),
|
tool_call_id: Some(tool_call_id.into()),
|
||||||
tool_name: Some(tool_name.into()),
|
tool_name: Some(tool_name.into()),
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -115,6 +115,16 @@ impl Session {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn append_persisted_messages<I>(&mut self, chat_id: &str, messages: I) -> Result<(), AgentError>
|
||||||
|
where
|
||||||
|
I: IntoIterator<Item = ChatMessage>,
|
||||||
|
{
|
||||||
|
for message in messages {
|
||||||
|
self.append_persisted_message(chat_id, message)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
pub fn create_user_message(&self, content: &str, media_refs: Vec<String>) -> ChatMessage {
|
||||||
if media_refs.is_empty() {
|
if media_refs.is_empty() {
|
||||||
ChatMessage::user(content)
|
ChatMessage::user(content)
|
||||||
@ -381,12 +391,12 @@ impl SessionManager {
|
|||||||
|
|
||||||
// 创建 agent 并处理
|
// 创建 agent 并处理
|
||||||
let agent = session_guard.create_agent()?;
|
let agent = session_guard.create_agent()?;
|
||||||
let response = agent.process(history).await?;
|
let result = agent.process(history).await?;
|
||||||
|
|
||||||
// 添加助手响应到历史
|
// 按真实顺序持久化 assistant tool_calls、tool 结果和最终 assistant 回复
|
||||||
session_guard.append_persisted_message(chat_id, response.clone())?;
|
session_guard.append_persisted_messages(chat_id, result.emitted_messages.clone())?;
|
||||||
|
|
||||||
response
|
result.final_response
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
|
|||||||
@ -171,13 +171,13 @@ async fn handle_inbound(
|
|||||||
|
|
||||||
let agent = session_guard.create_agent()?;
|
let agent = session_guard.create_agent()?;
|
||||||
match agent.process(history).await {
|
match agent.process(history).await {
|
||||||
Ok(response) => {
|
Ok(result) => {
|
||||||
session_guard.append_persisted_message(&chat_id, response.clone())?;
|
session_guard.append_persisted_messages(&chat_id, result.emitted_messages.clone())?;
|
||||||
let _ = session_guard
|
let _ = session_guard
|
||||||
.send(WsOutbound::AssistantResponse {
|
.send(WsOutbound::AssistantResponse {
|
||||||
id: response.id,
|
id: result.final_response.id,
|
||||||
content: response.content,
|
content: result.final_response.content,
|
||||||
role: response.role,
|
role: result.final_response.role,
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,6 +57,54 @@ impl OpenAIProvider {
|
|||||||
model_extra,
|
model_extra,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
|
let mut body = json!({
|
||||||
|
"model": self.model_id,
|
||||||
|
"messages": request.messages.iter().map(|m| {
|
||||||
|
if m.role == "tool" {
|
||||||
|
json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": convert_content_blocks(&m.content),
|
||||||
|
"tool_call_id": m.tool_call_id,
|
||||||
|
"name": m.name,
|
||||||
|
})
|
||||||
|
} else if m.role == "assistant" && m.tool_calls.is_some() {
|
||||||
|
json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": convert_content_blocks(&m.content),
|
||||||
|
"tool_calls": m.tool_calls.as_ref().map(|calls| {
|
||||||
|
calls.iter().map(|call| json!({
|
||||||
|
"id": call.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": call.name,
|
||||||
|
"arguments": serde_json::to_string(&call.arguments).unwrap_or_else(|_| "null".to_string())
|
||||||
|
}
|
||||||
|
})).collect::<Vec<_>>()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
json!({
|
||||||
|
"role": m.role,
|
||||||
|
"content": convert_content_blocks(&m.content)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}).collect::<Vec<_>>(),
|
||||||
|
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
||||||
|
"max_tokens": request.max_tokens.or(self.max_tokens),
|
||||||
|
});
|
||||||
|
|
||||||
|
for (key, value) in &self.model_extra {
|
||||||
|
body[key] = value.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(tools) = &request.tools {
|
||||||
|
body["tools"] = json!(tools);
|
||||||
|
}
|
||||||
|
|
||||||
|
body
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
@ -116,35 +164,7 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
) -> Result<ChatCompletionResponse, Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let url = format!("{}/chat/completions", self.base_url);
|
let url = format!("{}/chat/completions", self.base_url);
|
||||||
|
|
||||||
let mut body = json!({
|
let body = self.build_request_body(&request);
|
||||||
"model": self.model_id,
|
|
||||||
"messages": request.messages.iter().map(|m| {
|
|
||||||
if m.role == "tool" {
|
|
||||||
json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": convert_content_blocks(&m.content),
|
|
||||||
"tool_call_id": m.tool_call_id,
|
|
||||||
"name": m.name,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
json!({
|
|
||||||
"role": m.role,
|
|
||||||
"content": convert_content_blocks(&m.content)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}).collect::<Vec<_>>(),
|
|
||||||
"temperature": request.temperature.or(self.temperature).unwrap_or(0.7),
|
|
||||||
"max_tokens": request.max_tokens.or(self.max_tokens),
|
|
||||||
});
|
|
||||||
|
|
||||||
// Add model extra fields
|
|
||||||
for (key, value) in &self.model_extra {
|
|
||||||
body[key] = value.clone();
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(tools) = &request.tools {
|
|
||||||
body["tools"] = json!(tools);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Debug: Log LLM request summary (only in debug builds)
|
// Debug: Log LLM request summary (only in debug builds)
|
||||||
#[cfg(debug_assertions)]
|
#[cfg(debug_assertions)]
|
||||||
@ -242,3 +262,50 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
&self.model_id
|
&self.model_id
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::providers::Message;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_request_body_includes_assistant_tool_calls() {
|
||||||
|
let provider = OpenAIProvider::new(
|
||||||
|
"test".to_string(),
|
||||||
|
"key".to_string(),
|
||||||
|
"https://example.com/v1".to_string(),
|
||||||
|
HashMap::new(),
|
||||||
|
"gpt-test".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
HashMap::new(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![ContentBlock::text("calling tool")],
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: Some(vec![ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: json!({"expression": "1+1"}),
|
||||||
|
}]),
|
||||||
|
}],
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: None,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let body = provider.build_request_body(&request);
|
||||||
|
let messages = body["messages"].as_array().unwrap();
|
||||||
|
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tool_calls.len(), 1);
|
||||||
|
assert_eq!(tool_calls[0]["id"], "call_1");
|
||||||
|
assert_eq!(tool_calls[0]["type"], "function");
|
||||||
|
assert_eq!(tool_calls[0]["function"]["name"], "calculator");
|
||||||
|
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -10,6 +10,8 @@ pub struct Message {
|
|||||||
pub tool_call_id: Option<String>,
|
pub tool_call_id: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub name: Option<String>,
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Message {
|
impl Message {
|
||||||
@ -19,6 +21,7 @@ impl Message {
|
|||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -28,6 +31,7 @@ impl Message {
|
|||||||
content,
|
content,
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,6 +41,7 @@ impl Message {
|
|||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -46,6 +51,7 @@ impl Message {
|
|||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
name: None,
|
name: None,
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,6 +61,7 @@ impl Message {
|
|||||||
content: vec![ContentBlock::text(content)],
|
content: vec![ContentBlock::text(content)],
|
||||||
tool_call_id: Some(tool_call_id.into()),
|
tool_call_id: Some(tool_call_id.into()),
|
||||||
name: Some(tool_name.into()),
|
name: Some(tool_name.into()),
|
||||||
|
tool_calls: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -85,6 +85,7 @@ impl SessionStore {
|
|||||||
media_refs_json TEXT NOT NULL,
|
media_refs_json TEXT NOT NULL,
|
||||||
tool_call_id TEXT,
|
tool_call_id TEXT,
|
||||||
tool_name TEXT,
|
tool_name TEXT,
|
||||||
|
tool_calls_json TEXT,
|
||||||
created_at INTEGER NOT NULL,
|
created_at INTEGER NOT NULL,
|
||||||
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
|
FOREIGN KEY(session_id) REFERENCES sessions(id) ON DELETE CASCADE,
|
||||||
UNIQUE(session_id, seq)
|
UNIQUE(session_id, seq)
|
||||||
@ -97,6 +98,10 @@ impl SessionStore {
|
|||||||
",
|
",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
if !table_has_column(&conn, "messages", "tool_calls_json")? {
|
||||||
|
conn.execute("ALTER TABLE messages ADD COLUMN tool_calls_json TEXT", [])?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conn: Arc::new(Mutex::new(conn)),
|
conn: Arc::new(Mutex::new(conn)),
|
||||||
})
|
})
|
||||||
@ -260,12 +265,13 @@ impl SessionStore {
|
|||||||
)?;
|
)?;
|
||||||
|
|
||||||
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
let media_refs_json = serde_json::to_string(&message.media_refs)?;
|
||||||
|
let tool_calls_json = message.tool_calls.as_ref().map(serde_json::to_string).transpose()?;
|
||||||
tx.execute(
|
tx.execute(
|
||||||
"
|
"
|
||||||
INSERT INTO messages (
|
INSERT INTO messages (
|
||||||
id, session_id, seq, role, content,
|
id, session_id, seq, role, content,
|
||||||
media_refs_json, tool_call_id, tool_name, created_at
|
media_refs_json, tool_call_id, tool_name, tool_calls_json, created_at
|
||||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)
|
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)
|
||||||
",
|
",
|
||||||
params![
|
params![
|
||||||
message.id,
|
message.id,
|
||||||
@ -276,6 +282,7 @@ impl SessionStore {
|
|||||||
media_refs_json,
|
media_refs_json,
|
||||||
message.tool_call_id,
|
message.tool_call_id,
|
||||||
message.tool_name,
|
message.tool_name,
|
||||||
|
tool_calls_json,
|
||||||
message.timestamp,
|
message.timestamp,
|
||||||
],
|
],
|
||||||
)?;
|
)?;
|
||||||
@ -301,7 +308,7 @@ impl SessionStore {
|
|||||||
let conn = self.conn.lock().expect("session db mutex poisoned");
|
let conn = self.conn.lock().expect("session db mutex poisoned");
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"
|
"
|
||||||
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name
|
SELECT id, role, content, media_refs_json, created_at, tool_call_id, tool_name, tool_calls_json
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE session_id = ?1
|
WHERE session_id = ?1
|
||||||
ORDER BY seq ASC
|
ORDER BY seq ASC
|
||||||
@ -318,6 +325,19 @@ impl SessionStore {
|
|||||||
)
|
)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
let tool_calls_json: Option<String> = row.get(7)?;
|
||||||
|
let tool_calls = tool_calls_json
|
||||||
|
.as_deref()
|
||||||
|
.map(serde_json::from_str)
|
||||||
|
.transpose()
|
||||||
|
.map_err(|err| {
|
||||||
|
rusqlite::Error::FromSqlConversionFailure(
|
||||||
|
7,
|
||||||
|
rusqlite::types::Type::Text,
|
||||||
|
Box::new(err),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
Ok(ChatMessage {
|
Ok(ChatMessage {
|
||||||
id: row.get(0)?,
|
id: row.get(0)?,
|
||||||
role: row.get(1)?,
|
role: row.get(1)?,
|
||||||
@ -326,6 +346,7 @@ impl SessionStore {
|
|||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
tool_call_id: row.get(5)?,
|
tool_call_id: row.get(5)?,
|
||||||
tool_name: row.get(6)?,
|
tool_name: row.get(6)?,
|
||||||
|
tool_calls,
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
@ -376,6 +397,7 @@ fn current_timestamp() -> i64 {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::providers::ToolCall;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_persistent_session_id_for_cli_and_channel() {
|
fn test_persistent_session_id_for_cli_and_channel() {
|
||||||
@ -444,4 +466,60 @@ mod tests {
|
|||||||
assert_eq!(first.chat_id, "chat-1");
|
assert_eq!(first.chat_id, "chat-1");
|
||||||
assert_eq!(second.channel_name, "feishu");
|
assert_eq!(second.channel_name, "feishu");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_assistant_tool_calls_roundtrip() {
|
||||||
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
let session = store.create_cli_session(Some("tools")).unwrap();
|
||||||
|
|
||||||
|
let assistant = ChatMessage::assistant_with_tool_calls(
|
||||||
|
"calling tool",
|
||||||
|
vec![ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: serde_json::json!({ "expression": "3*7" }),
|
||||||
|
}],
|
||||||
|
);
|
||||||
|
|
||||||
|
store.append_message(&session.id, &assistant).unwrap();
|
||||||
|
|
||||||
|
let messages = store.load_messages(&session.id).unwrap();
|
||||||
|
assert_eq!(messages.len(), 1);
|
||||||
|
assert_eq!(messages[0].role, "assistant");
|
||||||
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap().len(), 1);
|
||||||
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].id, "call_1");
|
||||||
|
assert_eq!(messages[0].tool_calls.as_ref().unwrap()[0].name, "calculator");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tool_result_roundtrip() {
|
||||||
|
let store = SessionStore::in_memory().unwrap();
|
||||||
|
let session = store.create_cli_session(Some("tool-result")).unwrap();
|
||||||
|
|
||||||
|
let tool_message = ChatMessage::tool("call_9", "file_write", "saved to /tmp/output.txt");
|
||||||
|
store.append_message(&session.id, &tool_message).unwrap();
|
||||||
|
|
||||||
|
let messages = store.load_messages(&session.id).unwrap();
|
||||||
|
assert_eq!(messages.len(), 1);
|
||||||
|
assert_eq!(messages[0].role, "tool");
|
||||||
|
assert_eq!(messages[0].content, "saved to /tmp/output.txt");
|
||||||
|
assert_eq!(messages[0].tool_call_id.as_deref(), Some("call_9"));
|
||||||
|
assert_eq!(messages[0].tool_name.as_deref(), Some("file_write"));
|
||||||
|
assert!(messages[0].tool_calls.is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn table_has_column(conn: &Connection, table: &str, column: &str) -> Result<bool, rusqlite::Error> {
|
||||||
|
let pragma = format!("PRAGMA table_info({})", table);
|
||||||
|
let mut stmt = conn.prepare(&pragma)?;
|
||||||
|
let mut rows = stmt.query([])?;
|
||||||
|
|
||||||
|
while let Some(row) = rows.next()? {
|
||||||
|
let name: String = row.get(1)?;
|
||||||
|
if name == column {
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(false)
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user