feat: 增强OpenAIProvider以支持JSON工具参数,优化调度管理工具的参数验证和错误处理
This commit is contained in:
parent
88e1bfd9f2
commit
c83d697f93
@ -36,6 +36,13 @@ pub struct OpenAIProvider {
|
|||||||
model_extra: HashMap<String, serde_json::Value>,
|
model_extra: HashMap<String, serde_json::Value>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
enum OAIFunctionArguments {
|
||||||
|
Json(Value),
|
||||||
|
String(String),
|
||||||
|
}
|
||||||
|
|
||||||
impl OpenAIProvider {
|
impl OpenAIProvider {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
name: String,
|
name: String,
|
||||||
@ -67,6 +74,23 @@ impl OpenAIProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn uses_json_tool_arguments(&self) -> bool {
|
||||||
|
self.model_extra
|
||||||
|
.get("tool_call_arguments_json")
|
||||||
|
.and_then(|value| value.as_bool())
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn serialize_tool_arguments(&self, arguments: &Value) -> Value {
|
||||||
|
if self.uses_json_tool_arguments() {
|
||||||
|
arguments.clone()
|
||||||
|
} else {
|
||||||
|
Value::String(
|
||||||
|
serde_json::to_string(arguments).unwrap_or_else(|_| "null".to_string()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
fn build_request_body(&self, request: &ChatCompletionRequest) -> Value {
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
@ -88,7 +112,7 @@ impl OpenAIProvider {
|
|||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": call.name,
|
"name": call.name,
|
||||||
"arguments": serde_json::to_string(&call.arguments).unwrap_or_else(|_| "null".to_string())
|
"arguments": self.serialize_tool_arguments(&call.arguments)
|
||||||
}
|
}
|
||||||
})).collect::<Vec<_>>()
|
})).collect::<Vec<_>>()
|
||||||
})
|
})
|
||||||
@ -170,7 +194,7 @@ struct OpenAIToolCall {
|
|||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct OAIFunction {
|
struct OAIFunction {
|
||||||
name: String,
|
name: String,
|
||||||
arguments: String,
|
arguments: OAIFunctionArguments,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize, Default)]
|
#[derive(Deserialize, Default)]
|
||||||
@ -260,7 +284,12 @@ impl LLMProvider for OpenAIProvider {
|
|||||||
.map(|tc| ToolCall {
|
.map(|tc| ToolCall {
|
||||||
id: tc.id.clone(),
|
id: tc.id.clone(),
|
||||||
name: tc.function.name.clone(),
|
name: tc.function.name.clone(),
|
||||||
arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null),
|
arguments: match &tc.function.arguments {
|
||||||
|
OAIFunctionArguments::Json(arguments) => arguments.clone(),
|
||||||
|
OAIFunctionArguments::String(arguments) => {
|
||||||
|
serde_json::from_str(arguments).unwrap_or(serde_json::Value::Null)
|
||||||
|
}
|
||||||
|
},
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
@ -339,6 +368,48 @@ mod tests {
|
|||||||
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_build_request_body_uses_json_tool_arguments_when_enabled() {
|
||||||
|
let provider = OpenAIProvider::new(
|
||||||
|
"test".to_string(),
|
||||||
|
"key".to_string(),
|
||||||
|
"https://example.com/v1".to_string(),
|
||||||
|
HashMap::new(),
|
||||||
|
120,
|
||||||
|
"gpt-test".to_string(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
HashMap::from([(
|
||||||
|
"tool_call_arguments_json".to_string(),
|
||||||
|
Value::Bool(true),
|
||||||
|
)]),
|
||||||
|
);
|
||||||
|
|
||||||
|
let request = ChatCompletionRequest {
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![ContentBlock::text("calling tool")],
|
||||||
|
reasoning_content: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
tool_calls: Some(vec![ToolCall {
|
||||||
|
id: "call_1".to_string(),
|
||||||
|
name: "calculator".to_string(),
|
||||||
|
arguments: json!({"expression": "1+1"}),
|
||||||
|
}]),
|
||||||
|
}],
|
||||||
|
temperature: None,
|
||||||
|
max_tokens: None,
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let body = provider.build_request_body(&request);
|
||||||
|
let messages = body["messages"].as_array().unwrap();
|
||||||
|
let tool_calls = messages[0]["tool_calls"].as_array().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"}));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_build_request_body_includes_assistant_reasoning_content() {
|
fn test_build_request_body_includes_assistant_reasoning_content() {
|
||||||
let provider = OpenAIProvider::new(
|
let provider = OpenAIProvider::new(
|
||||||
@ -395,4 +466,37 @@ mod tests {
|
|||||||
|
|
||||||
assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning"));
|
assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_openai_response_parses_json_tool_arguments() {
|
||||||
|
let response: OpenAIResponse = serde_json::from_value(json!({
|
||||||
|
"id": "resp_1",
|
||||||
|
"model": "gpt-test",
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": [{
|
||||||
|
"id": "call_1",
|
||||||
|
"function": {
|
||||||
|
"name": "scheduler_manage",
|
||||||
|
"arguments": {"action": "list"}
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 1,
|
||||||
|
"completion_tokens": 1,
|
||||||
|
"total_tokens": 2
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
match &response.choices[0].message.tool_calls[0].function.arguments {
|
||||||
|
OAIFunctionArguments::Json(arguments) => {
|
||||||
|
assert_eq!(arguments, &json!({"action": "list"}));
|
||||||
|
}
|
||||||
|
OAIFunctionArguments::String(_) => panic!("expected JSON tool arguments"),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,6 +35,21 @@ impl Tool for SchedulerManageTool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn parameters_schema(&self) -> serde_json::Value {
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
let mut allowed_agents = self
|
||||||
|
.known_agents
|
||||||
|
.iter()
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
allowed_agents.sort();
|
||||||
|
let agent_hint = if allowed_agents.is_empty() {
|
||||||
|
"agent_task payload.agent may be omitted or set to 'default'.".to_string()
|
||||||
|
} else {
|
||||||
|
format!(
|
||||||
|
"agent_task payload.agent may be omitted, set to 'default', or use one of configured agents: {}.",
|
||||||
|
allowed_agents.join(", ")
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
json!({
|
json!({
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"properties": {
|
"properties": {
|
||||||
@ -63,7 +78,7 @@ impl Tool for SchedulerManageTool {
|
|||||||
},
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"description": "Job payload. agent_task supports prompt, agent, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event."
|
"description": format!("Job payload. agent_task supports prompt, agent, fresh_session, system_prompt, sender_id, metadata. {} outbound_message expects content. internal_event expects event.", agent_hint)
|
||||||
},
|
},
|
||||||
"max_runs": {
|
"max_runs": {
|
||||||
"type": ["integer", "null"]
|
"type": ["integer", "null"]
|
||||||
@ -83,6 +98,18 @@ impl Tool for SchedulerManageTool {
|
|||||||
context: &crate::tools::ToolContext,
|
context: &crate::tools::ToolContext,
|
||||||
args: serde_json::Value,
|
args: serde_json::Value,
|
||||||
) -> anyhow::Result<ToolResult> {
|
) -> anyhow::Result<ToolResult> {
|
||||||
|
if args.is_null() {
|
||||||
|
return Ok(error_result(
|
||||||
|
"Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
if !args.is_object() {
|
||||||
|
return Ok(error_result(
|
||||||
|
"Invalid parameters: scheduler_manage expects a JSON object",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let action = match args.get("action").and_then(|value| value.as_str()) {
|
let action = match args.get("action").and_then(|value| value.as_str()) {
|
||||||
Some(action) => action,
|
Some(action) => action,
|
||||||
None => return Ok(error_result("Missing required parameter: action")),
|
None => return Ok(error_result("Missing required parameter: action")),
|
||||||
@ -263,7 +290,28 @@ fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashS
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::bail!("Unknown agent '{}' for agent_task payload.agent", normalized)
|
anyhow::bail!(unknown_agent_message(normalized, known_agents))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unknown_agent_message(agent_name: &str, known_agents: &HashSet<String>) -> String {
|
||||||
|
let mut configured_agents = known_agents.iter().cloned().collect::<Vec<_>>();
|
||||||
|
configured_agents.sort();
|
||||||
|
|
||||||
|
let configured_hint = if configured_agents.is_empty() {
|
||||||
|
"No named agents are configured; use payload.agent='default' or omit payload.agent.".to_string()
|
||||||
|
} else {
|
||||||
|
format!(
|
||||||
|
"payload.agent must be omitted, set to 'default', or use one of configured agents: default, {}.",
|
||||||
|
configured_agents.join(", ")
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
format!(
|
||||||
|
"Unknown agent '{}' for agent_task payload.agent. {} '{}' is not an agent. If you mean a skill, do not put it in payload.agent.",
|
||||||
|
agent_name,
|
||||||
|
configured_hint,
|
||||||
|
agent_name,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn validate_outbound_message_payload(payload: &serde_json::Value) -> anyhow::Result<()> {
|
fn validate_outbound_message_payload(payload: &serde_json::Value) -> anyhow::Result<()> {
|
||||||
@ -524,6 +572,51 @@ mod tests {
|
|||||||
|
|
||||||
assert!(put_result.is_err());
|
assert!(put_result.is_err());
|
||||||
let error = put_result.err().unwrap().to_string();
|
let error = put_result.err().unwrap().to_string();
|
||||||
assert!(error.contains("Unknown agent 'missing-agent'"));
|
assert!(error.contains("Unknown agent 'missing-agent' for agent_task payload.agent"));
|
||||||
|
assert!(error.contains("payload.agent must be omitted, set to 'default', or use one of configured agents: default, planner"));
|
||||||
|
assert!(error.contains("If you mean a skill, do not put it in payload.agent"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_scheduler_manage_accepts_default_agent() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()]));
|
||||||
|
|
||||||
|
let put_result = tool
|
||||||
|
.execute(json!({
|
||||||
|
"action": "put",
|
||||||
|
"id": "agent.default_summary",
|
||||||
|
"kind": "agent_task",
|
||||||
|
"schedule": {
|
||||||
|
"type": "cron",
|
||||||
|
"expression": "0 9 * * *"
|
||||||
|
},
|
||||||
|
"target": {
|
||||||
|
"channel": "feishu",
|
||||||
|
"chat_id": "oc_demo"
|
||||||
|
},
|
||||||
|
"payload": {
|
||||||
|
"prompt": "请总结今天待办",
|
||||||
|
"agent": "default"
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(put_result.success);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_scheduler_manage_rejects_null_args_locally() {
|
||||||
|
let store = Arc::new(SessionStore::in_memory().unwrap());
|
||||||
|
let tool = SchedulerManageTool::new(store, HashSet::new());
|
||||||
|
|
||||||
|
let result = tool.execute(serde_json::Value::Null).await.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert_eq!(
|
||||||
|
result.error.as_deref(),
|
||||||
|
Some("Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user