diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index 61a80ff1..5ba5ddec 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -225,9 +225,6 @@ class ConfigMixin:
             text = reader.read()
         return json.loads(text)
 
-    def __eq__(self, other):
-        return self.__dict__ == other.__dict__
-
     def __repr__(self):
         return f"{self.__class__.__name__} {self.to_json_string()}"
 
diff --git a/src/diffusers/pipelines/pipeline_glide.py b/src/diffusers/pipelines/pipeline_glide.py
index 1f6495e8..138ce9d2 100644
--- a/src/diffusers/pipelines/pipeline_glide.py
+++ b/src/diffusers/pipelines/pipeline_glide.py
@@ -832,12 +832,12 @@ class GLIDE(DiffusionPipeline):
 
         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.text_noise_scheduler.sample_noise(
-            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
-        )
+        image = torch.randn(
+            (batch_size, self.text_unet.in_channels, 64, 64), generator=generator
+        ).to(torch_device)
 
         # 2. Encode tokens
-        # an empty input is needed to guide the model away from (
+        # an empty input is needed to guide the model away from it
         inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
         input_ids = inputs["input_ids"].to(torch_device)
         attention_mask = inputs["attention_mask"].to(torch_device)
@@ -850,7 +850,7 @@ class GLIDE(DiffusionPipeline):
             mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                 text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
             )
-            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+            noise = torch.randn(image.shape, generator=generator).to(torch_device)
             nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
             image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
 
@@ -873,8 +873,8 @@ class GLIDE(DiffusionPipeline):
                 self.upscale_unet.resolution,
             ),
             generator=generator,
-        )
-        image = image.to(torch_device) * upsample_temp
+        ).to(torch_device)
+        image = image * upsample_temp
 
         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
@@ -896,7 +896,7 @@ class GLIDE(DiffusionPipeline):
             # 3. optionally sample variance
            variance = 0
            if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
+                noise = torch.randn(image.shape, generator=generator).to(torch_device)
                 variance = (
                     self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
                 )
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 417ef353..6db88316 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -19,7 +19,7 @@ import unittest
 
 import torch
 
-from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
+from diffusers import DDIM, DDPM, PNDM, GLIDE, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -212,3 +212,18 @@ class PipelineTesterMixin(unittest.TestCase):
         assert image.shape == (1, 3, 256, 256)
         expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+
+    @slow
+    def test_glide_text2img(self):
+        model_id = "fusing/glide-base"
+        glide = GLIDE.from_pretrained(model_id)
+
+        prompt = "a pencil sketch of a corgi"
+        generator = torch.manual_seed(0)
+        image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
+
+        image_slice = image[0, :3, :3, -1].cpu()
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
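
The common thread across the pipeline hunks is replacing the scheduler's sample_noise helper with plain torch.randn driven by an explicit generator: the noise is drawn on CPU and then moved to the target device, so the random values stay reproducible regardless of where inference runs. A minimal standalone sketch of that pattern (not part of the diff; the shape and device below are purely illustrative):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    generator = torch.manual_seed(0)  # seeds and returns the default CPU generator

    # Draw the noise on CPU with the seeded generator, then move it to the device;
    # the sampled values are identical whether the model later runs on CPU or GPU.
    shape = (2, 3, 64, 64)  # illustrative (batch, channels, height, width)
    noise = torch.randn(shape, generator=generator).to(device)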