from diffusers.utils import is_flax_available
from diffusers.utils.testing_utils import require_flax


if is_flax_available():
    import jax


@require_flax
class FlaxModelTesterMixin:
    """Mixin with shared forward-pass shape checks for Flax models.

    The host test class is expected to provide:

    * ``model_class`` -- the model class, instantiated as
      ``model_class(**init_dict)``.
    * ``prepare_init_args_and_inputs_for_common()`` -- returns an
      ``(init_dict, inputs_dict)`` pair, where ``inputs_dict`` contains at
      least the keys ``"prng_key"`` and ``"sample"``.
    * ``assertIsNotNone`` / ``assertEqual`` -- the assertion methods used
      below (presumably inherited from ``unittest.TestCase``; confirm against
      the concrete test classes).
    """

    def _check_output_matches_input_shape(self, init_dict, inputs_dict):
        """Build the model, run one forward pass, and compare shapes.

        Args:
            init_dict: Keyword arguments passed to ``self.model_class``.
            inputs_dict: Mapping that provides ``"prng_key"`` (used to
                initialise the variables) and ``"sample"`` (the model input).

        The check fails if the output is ``None`` or if its shape differs
        from the shape of ``inputs_dict["sample"]``.
        """
        sample = inputs_dict["sample"]

        model = self.model_class(**init_dict)
        variables = model.init(inputs_dict["prng_key"], sample)
        # stop_gradient returns a new pytree instead of modifying its
        # argument, so the result has to be rebound to have any effect.
        variables = jax.lax.stop_gradient(variables)

        output = model.apply(variables, sample)

        # Dict-like model outputs expose the result as a ``sample`` attribute.
        if isinstance(output, dict):
            output = output.sample

        self.assertIsNotNone(output)
        expected_shape = sample.shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_output(self):
        """Check that a default-configured model preserves the input shape."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        self._check_output_matches_input_shape(init_dict, inputs_dict)

    def test_forward_with_norm_groups(self):
        """Check shape preservation with overridden norm groups and channels."""
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        # Override the common configuration before building the model.
        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        self._check_output_matches_input_shape(init_dict, inputs_dict)