diff --git a/router/src/lib.rs b/router/src/lib.rs index f5ddcc56..4f9637d3 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -1616,4 +1616,36 @@ mod tests { r#"{"role":"assistant","tool_calls":[{"id":"0","type":"function","function":{"description":null,"name":"myfn","arguments":{"format":"csv"}}}]}"# ); } + + #[test] + fn tool_choice_formats() { + + #[derive(Deserialize)] + struct TestRequest { + tool_choice: ToolChoice, + } + + let none = r#"{"tool_choice":"none"}"#; + let de_none: TestRequest = serde_json::from_str(none).unwrap(); + assert_eq!(de_none.tool_choice, ToolChoice(Some(ToolType::NoTool))); + + let auto = r#"{"tool_choice":"auto"}"#; + let de_auto: TestRequest = serde_json::from_str(auto).unwrap(); + assert_eq!(de_auto.tool_choice, ToolChoice(Some(ToolType::OneOf))); + + let ref_choice = ToolChoice(Some(ToolType::Function(FunctionName { name: "myfn".to_string() }))); + + let named = r#"{"tool_choice":"myfn"}"#; + let de_named: TestRequest = serde_json::from_str(named).unwrap(); + assert_eq!(de_named.tool_choice, ref_choice); + + let old_named = r#"{"tool_choice":{"function":{"name":"myfn"}}}"#; + let de_old_named: TestRequest = serde_json::from_str(old_named).unwrap(); + assert_eq!(de_old_named.tool_choice, ref_choice); + + let openai_named = r#"{"tool_choice":{"type":"function","function":{"name":"myfn"}}}"#; + let de_openai_named: TestRequest = serde_json::from_str(openai_named).unwrap(); + + assert_eq!(de_openai_named.tool_choice, ref_choice); + } }