From c83d697f93da67ec50f7b113f257756aa7754094 Mon Sep 17 00:00:00 2001 From: ooodc <549496103@qq.com> Date: Sun, 26 Apr 2026 23:05:34 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=BC=BAOpenAIProvider?= =?UTF-8?q?=E4=BB=A5=E6=94=AF=E6=8C=81JSON=E5=B7=A5=E5=85=B7=E5=8F=82?= =?UTF-8?q?=E6=95=B0=EF=BC=8C=E4=BC=98=E5=8C=96=E8=B0=83=E5=BA=A6=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E5=B7=A5=E5=85=B7=E7=9A=84=E5=8F=82=E6=95=B0=E9=AA=8C?= =?UTF-8?q?=E8=AF=81=E5=92=8C=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/providers/openai.rs | 110 +++++++++++++++++++++++++++++++++- src/tools/scheduler_manage.rs | 99 +++++++++++++++++++++++++++++- 2 files changed, 203 insertions(+), 6 deletions(-) diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 5c74cf3..3549e91 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -36,6 +36,13 @@ pub struct OpenAIProvider { model_extra: HashMap, } +#[derive(Deserialize)] +#[serde(untagged)] +enum OAIFunctionArguments { + Json(Value), + String(String), +} + impl OpenAIProvider { pub fn new( name: String, @@ -67,6 +74,23 @@ impl OpenAIProvider { } } + fn uses_json_tool_arguments(&self) -> bool { + self.model_extra + .get("tool_call_arguments_json") + .and_then(|value| value.as_bool()) + .unwrap_or(false) + } + + fn serialize_tool_arguments(&self, arguments: &Value) -> Value { + if self.uses_json_tool_arguments() { + arguments.clone() + } else { + Value::String( + serde_json::to_string(arguments).unwrap_or_else(|_| "null".to_string()), + ) + } + } + fn build_request_body(&self, request: &ChatCompletionRequest) -> Value { let mut body = json!({ "model": self.model_id, @@ -88,7 +112,7 @@ impl OpenAIProvider { "type": "function", "function": { "name": call.name, - "arguments": serde_json::to_string(&call.arguments).unwrap_or_else(|_| "null".to_string()) + "arguments": self.serialize_tool_arguments(&call.arguments) } })).collect::>() }) @@ -170,7 +194,7 @@ struct OpenAIToolCall { #[derive(Deserialize)] struct OAIFunction { name: String, - arguments: String, + arguments: OAIFunctionArguments, } #[derive(Deserialize, Default)] @@ -260,7 +284,12 @@ impl LLMProvider for OpenAIProvider { .map(|tc| ToolCall { id: tc.id.clone(), name: tc.function.name.clone(), - arguments: serde_json::from_str(&tc.function.arguments).unwrap_or(serde_json::Value::Null), + arguments: match &tc.function.arguments { + OAIFunctionArguments::Json(arguments) => arguments.clone(), + OAIFunctionArguments::String(arguments) => { + serde_json::from_str(arguments).unwrap_or(serde_json::Value::Null) + } + }, }) .collect(); @@ -339,6 +368,48 @@ mod tests { assert_eq!(tool_calls[0]["function"]["arguments"], "{\"expression\":\"1+1\"}"); } + #[test] + fn test_build_request_body_uses_json_tool_arguments_when_enabled() { + let provider = OpenAIProvider::new( + "test".to_string(), + "key".to_string(), + "https://example.com/v1".to_string(), + HashMap::new(), + 120, + "gpt-test".to_string(), + None, + None, + HashMap::from([( + "tool_call_arguments_json".to_string(), + Value::Bool(true), + )]), + ); + + let request = ChatCompletionRequest { + messages: vec![Message { + role: "assistant".to_string(), + content: vec![ContentBlock::text("calling tool")], + reasoning_content: None, + tool_call_id: None, + name: None, + tool_calls: Some(vec![ToolCall { + id: "call_1".to_string(), + name: "calculator".to_string(), + arguments: json!({"expression": "1+1"}), + }]), + }], + temperature: None, + max_tokens: None, + tools: None, + }; + + let body = provider.build_request_body(&request); + let messages = body["messages"].as_array().unwrap(); + let tool_calls = messages[0]["tool_calls"].as_array().unwrap(); + + assert_eq!(tool_calls[0]["function"]["arguments"], json!({"expression": "1+1"})); + } + #[test] fn test_build_request_body_includes_assistant_reasoning_content() { let provider = OpenAIProvider::new( @@ -395,4 +466,37 @@ mod tests { assert_eq!(response.choices[0].message.reasoning_content.as_deref(), Some("hidden reasoning")); } + + #[test] + fn test_openai_response_parses_json_tool_arguments() { + let response: OpenAIResponse = serde_json::from_value(json!({ + "id": "resp_1", + "model": "gpt-test", + "choices": [{ + "message": { + "content": "", + "tool_calls": [{ + "id": "call_1", + "function": { + "name": "scheduler_manage", + "arguments": {"action": "list"} + } + }] + } + }], + "usage": { + "prompt_tokens": 1, + "completion_tokens": 1, + "total_tokens": 2 + } + })) + .unwrap(); + + match &response.choices[0].message.tool_calls[0].function.arguments { + OAIFunctionArguments::Json(arguments) => { + assert_eq!(arguments, &json!({"action": "list"})); + } + OAIFunctionArguments::String(_) => panic!("expected JSON tool arguments"), + } + } } diff --git a/src/tools/scheduler_manage.rs b/src/tools/scheduler_manage.rs index d769efd..b7fc57f 100644 --- a/src/tools/scheduler_manage.rs +++ b/src/tools/scheduler_manage.rs @@ -35,6 +35,21 @@ impl Tool for SchedulerManageTool { } fn parameters_schema(&self) -> serde_json::Value { + let mut allowed_agents = self + .known_agents + .iter() + .cloned() + .collect::>(); + allowed_agents.sort(); + let agent_hint = if allowed_agents.is_empty() { + "agent_task payload.agent may be omitted or set to 'default'.".to_string() + } else { + format!( + "agent_task payload.agent may be omitted, set to 'default', or use one of configured agents: {}.", + allowed_agents.join(", ") + ) + }; + json!({ "type": "object", "properties": { @@ -63,7 +78,7 @@ impl Tool for SchedulerManageTool { }, "payload": { "type": "object", - "description": "Job payload. agent_task supports prompt, agent, fresh_session, system_prompt, sender_id, metadata. outbound_message expects content. internal_event expects event." + "description": format!("Job payload. agent_task supports prompt, agent, fresh_session, system_prompt, sender_id, metadata. {} outbound_message expects content. internal_event expects event.", agent_hint) }, "max_runs": { "type": ["integer", "null"] @@ -83,6 +98,18 @@ impl Tool for SchedulerManageTool { context: &crate::tools::ToolContext, args: serde_json::Value, ) -> anyhow::Result { + if args.is_null() { + return Ok(error_result( + "Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}", + )); + } + + if !args.is_object() { + return Ok(error_result( + "Invalid parameters: scheduler_manage expects a JSON object", + )); + } + let action = match args.get("action").and_then(|value| value.as_str()) { Some(action) => action, None => return Ok(error_result("Missing required parameter: action")), @@ -263,7 +290,28 @@ fn validate_agent_task_payload(payload: &serde_json::Value, known_agents: &HashS return Ok(()); } - anyhow::bail!("Unknown agent '{}' for agent_task payload.agent", normalized) + anyhow::bail!(unknown_agent_message(normalized, known_agents)) +} + +fn unknown_agent_message(agent_name: &str, known_agents: &HashSet) -> String { + let mut configured_agents = known_agents.iter().cloned().collect::>(); + configured_agents.sort(); + + let configured_hint = if configured_agents.is_empty() { + "No named agents are configured; use payload.agent='default' or omit payload.agent.".to_string() + } else { + format!( + "payload.agent must be omitted, set to 'default', or use one of configured agents: default, {}.", + configured_agents.join(", ") + ) + }; + + format!( + "Unknown agent '{}' for agent_task payload.agent. {} '{}' is not an agent. If you mean a skill, do not put it in payload.agent.", + agent_name, + configured_hint, + agent_name, + ) } fn validate_outbound_message_payload(payload: &serde_json::Value) -> anyhow::Result<()> { @@ -524,6 +572,51 @@ mod tests { assert!(put_result.is_err()); let error = put_result.err().unwrap().to_string(); - assert!(error.contains("Unknown agent 'missing-agent'")); + assert!(error.contains("Unknown agent 'missing-agent' for agent_task payload.agent")); + assert!(error.contains("payload.agent must be omitted, set to 'default', or use one of configured agents: default, planner")); + assert!(error.contains("If you mean a skill, do not put it in payload.agent")); + } + + #[tokio::test] + async fn test_scheduler_manage_accepts_default_agent() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store, HashSet::from(["planner".to_string()])); + + let put_result = tool + .execute(json!({ + "action": "put", + "id": "agent.default_summary", + "kind": "agent_task", + "schedule": { + "type": "cron", + "expression": "0 9 * * *" + }, + "target": { + "channel": "feishu", + "chat_id": "oc_demo" + }, + "payload": { + "prompt": "请总结今天待办", + "agent": "default" + } + })) + .await + .unwrap(); + + assert!(put_result.success); + } + + #[tokio::test] + async fn test_scheduler_manage_rejects_null_args_locally() { + let store = Arc::new(SessionStore::in_memory().unwrap()); + let tool = SchedulerManageTool::new(store, HashSet::new()); + + let result = tool.execute(serde_json::Value::Null).await.unwrap(); + + assert!(!result.success); + assert_eq!( + result.error.as_deref(), + Some("Missing required parameters: scheduler_manage expects a JSON object like {\"action\":\"list\"}") + ); } } \ No newline at end of file