diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 2ecab7f5..a58759b2 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -31,6 +31,7 @@ from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
     GLIDESuperResUNetModel,
+    GLIDETextToImageUNetModel,
     LatentDiffusion,
     PNDMScheduler,
     UNetGradTTSModel,
@@ -261,8 +262,6 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
         sizes = (32, 32)
         low_res_size = (4, 4)
 
-        torch_device = "cpu"
-
         noise = torch.randn((batch_size, num_channels // 2) + sizes).to(torch_device)
         low_res = torch.randn((batch_size, 3) + low_res_size).to(torch_device)
         time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
@@ -343,6 +342,100 @@ class GLIDESuperResUNetTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
 
 
+class GLIDETextToImageUNetModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = GLIDETextToImageUNetModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 3
+        sizes = (32, 32)
+        transformer_dim = 32
+        seq_len = 16
+
+        noise = torch.randn((batch_size, num_channels) + sizes).to(torch_device)
+        emb = torch.randn((batch_size, seq_len, transformer_dim)).to(torch_device)
+        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
+
+        return {"x": noise, "timesteps": time_step, "transformer_out": emb}
+
+    @property
+    def get_input_shape(self):
+        return (3, 32, 32)
+
+    @property
+    def get_output_shape(self):
+        return (6, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "attention_resolutions": (2,),
+            "channel_mult": (1, 2),
+            "in_channels": 3,
+            "out_channels": 6,
+            "model_channels": 32,
+            "num_head_channels": 8,
+            "num_heads_upsample": 1,
+            "num_res_blocks": 2,
+            "resblock_updown": True,
+            "resolution": 32,
+            "use_scale_shift_norm": True,
+            "transformer_dim": 32,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_output(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+        output, _ = torch.split(output, 3, dim=1)
+
+        self.assertIsNotNone(output)
+        expected_shape = inputs_dict["x"].shape
+        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
+
+    def test_from_pretrained_hub(self):
+        model, loading_info = GLIDETextToImageUNetModel.from_pretrained(
+            "fusing/unet-glide-text2im-dummy", output_loading_info=True
+        )
+        self.assertIsNotNone(model)
+        self.assertEqual(len(loading_info["missing_keys"]), 0)
+
+        model.to(torch_device)
+        image = model(**self.dummy_input)
+
+        assert image is not None, "Make sure output is not None"
+
+    def test_output_pretrained(self):
+        model = GLIDETextToImageUNetModel.from_pretrained("fusing/unet-glide-text2im-dummy")
+
+        torch.manual_seed(0)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+
+        noise = torch.randn((1, model.config.in_channels, model.config.resolution, model.config.resolution)).to(
+            torch_device
+        )
+        emb = torch.randn((1, 16, model.config.transformer_dim)).to(torch_device)
+        time_step = torch.tensor([10] * noise.shape[0], device=torch_device)
+
+        with torch.no_grad():
+            output = model(noise, time_step, emb)
+
+        output, _ = torch.split(output, 3, dim=1)
+        output_slice = output[0, -1, -3:, -3:].flatten()
+        # fmt: off
+        expected_output_slice = torch.tensor([2.7766, -10.3558, -14.9149, -0.9376, -14.9175, -17.7679, -5.5565, -12.9521, -12.9845])
+        # fmt: on
+        self.assertTrue(torch.allclose(output_slice, expected_output_slice, atol=1e-3))
+
+
 class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
     model_class = UNetLDMModel
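
For reference, the new test class follows the standard ModelTesterMixin pattern: dummy_input feeds the model a noisy image x, integer timesteps, and a transformer_out text embedding, and the 6-channel output is split 3/3 along the channel dimension (GLIDE-style UNets predict a learned variance alongside the noise, so only the first half is compared against the input shape). Below is a minimal standalone sketch of the same forward pass; it assumes a diffusers checkout from this era that still exposes GLIDETextToImageUNetModel, and the mean/variance naming of the split halves is an interpretation, not taken from the patch itself.

import torch

from diffusers import GLIDETextToImageUNetModel

# Same tiny configuration as prepare_init_args_and_inputs_for_common above.
model = GLIDETextToImageUNetModel(
    attention_resolutions=(2,),
    channel_mult=(1, 2),
    in_channels=3,
    out_channels=6,
    model_channels=32,
    num_head_channels=8,
    num_heads_upsample=1,
    num_res_blocks=2,
    resblock_updown=True,
    resolution=32,
    use_scale_shift_norm=True,
    transformer_dim=32,
)
model.eval()

noise = torch.randn(4, 3, 32, 32)      # x: batch of noisy 32x32 RGB images
timesteps = torch.tensor([10] * 4)     # one diffusion step index per sample
emb = torch.randn(4, 16, 32)           # transformer_out: (batch, seq_len, transformer_dim)

with torch.no_grad():
    out = model(noise, timesteps, emb)  # shape (4, 6, 32, 32)

# out_channels=6: split into two 3-channel halves, as the tests do;
# the second half (presumably the learned variance) is discarded there.
mean, variance = torch.split(out, 3, dim=1)
print(mean.shape)  # torch.Size([4, 3, 32, 32]) -- matches the input shape

The suite itself can be run in isolation with pytest node-id selection, e.g. python -m pytest tests/test_modeling_utils.py::GLIDETextToImageUNetModelTests -v.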