This commit is contained in:
Patrick von Platen 2022-06-14 10:39:25 +00:00
commit 9b7e6f495f
3 changed files with 24 additions and 12 deletions

View File

@ -225,9 +225,6 @@ class ConfigMixin:
text = reader.read()
return json.loads(text)
def __eq__(self, other):
return self.__dict__ == other.__dict__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"

View File

@ -832,12 +832,12 @@ class GLIDE(DiffusionPipeline):
# 1. Sample gaussian noise
batch_size = 2 # second image is empty for classifier-free guidance
image = self.text_noise_scheduler.sample_noise(
(batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
)
image = torch.randn(
(batch_size, self.text_unet.in_channels, 64, 64), generator=generator
).to(torch_device)
# 2. Encode tokens
# an empty input is needed to guide the model away from (
# an empty input is needed to guide the model away from it
inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
input_ids = inputs["input_ids"].to(torch_device)
attention_mask = inputs["attention_mask"].to(torch_device)
@ -850,7 +850,7 @@ class GLIDE(DiffusionPipeline):
mean, variance, log_variance, pred_xstart = self.p_mean_variance(
text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
)
noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
noise = torch.randn(image.shape, generator=generator).to(torch_device)
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
@ -873,8 +873,8 @@ class GLIDE(DiffusionPipeline):
self.upscale_unet.resolution,
),
generator=generator,
)
image = image.to(torch_device) * upsample_temp
).to(torch_device)
image = image * upsample_temp
num_trained_timesteps = self.upscale_noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
@ -896,7 +896,7 @@ class GLIDE(DiffusionPipeline):
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
noise = torch.randn(image.shape, generator=generator).to(torch_device)
variance = (
self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
)

View File

@ -19,7 +19,7 @@ import unittest
import torch
from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
from diffusers import DDIM, DDPM, PNDM, GLIDE, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.testing_utils import floats_tensor, slow, torch_device
@ -212,3 +212,18 @@ class PipelineTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_glide_text2img(self):
model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0)
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
image_slice = image[0, :3, :3, -1].cpu()
assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2