Merge branch 'main' of https://github.com/huggingface/diffusers into main
commit 9b7e6f495f
@@ -225,9 +225,6 @@ class ConfigMixin:
             text = reader.read()
         return json.loads(text)
 
-    def __eq__(self, other):
-        return self.__dict__ == other.__dict__
-
     def __repr__(self):
         return f"{self.__class__.__name__} {self.to_json_string()}"
 
@@ -832,12 +832,12 @@ class GLIDE(DiffusionPipeline):
 
         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = self.text_noise_scheduler.sample_noise(
-            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
-        )
+        image = torch.randn(
+            (batch_size, self.text_unet.in_channels, 64, 64), generator=generator
+        ).to(torch_device)
 
         # 2. Encode tokens
-        # an empty input is needed to guide the model away from (
+        # an empty input is needed to guide the model away from it
         inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
         input_ids = inputs["input_ids"].to(torch_device)
         attention_mask = inputs["attention_mask"].to(torch_device)
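The recurring change in this pipeline swaps the scheduler's sample_noise helper for plain torch.randn, drawing the noise with the seeded generator first and moving it to the device afterwards. A minimal sketch of that pattern, with an illustrative shape and device rather than the pipeline's actual values:

import torch

# Draw noise with a seeded CPU generator, then move it to the target device.
# The shape and device here are illustrative, not taken from the pipeline.
shape = (2, 3, 64, 64)  # hypothetical (batch, channels, height, width)
generator = torch.manual_seed(0)  # seeding the CPU generator keeps the draw reproducible
device = "cuda" if torch.cuda.is_available() else "cpu"
noise = torch.randn(shape, generator=generator).to(device)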
@@ -850,7 +850,7 @@ class GLIDE(DiffusionPipeline):
             mean, variance, log_variance, pred_xstart = self.p_mean_variance(
                 text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
             )
-            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
+            noise = torch.randn(image.shape, generator=generator).to(torch_device)
             nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
             image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
 
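The update in this hunk is the standard ancestral sampling step, x_prev = mean + exp(0.5 * log_variance) * noise, with the noise masked out on the final step (t == 0). A self-contained sketch of that step, using illustrative names rather than the pipeline's own variables:

import torch

def ancestral_step(mean, log_variance, t, generator=None):
    # x_prev = mean + sigma * eps, with sigma = exp(0.5 * log_variance).
    # `t` is a per-sample timestep tensor; the mask removes the noise term
    # when t == 0 so the final step is deterministic.
    noise = torch.randn(mean.shape, generator=generator).to(mean.device)
    nonzero_mask = (t != 0).float().view(-1, *([1] * (len(mean.shape) - 1)))
    return mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise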
@@ -873,8 +873,8 @@ class GLIDE(DiffusionPipeline):
                 self.upscale_unet.resolution,
             ),
             generator=generator,
-        )
-        image = image.to(torch_device) * upsample_temp
+        ).to(torch_device)
+        image = image * upsample_temp
 
         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
@@ -896,7 +896,7 @@ class GLIDE(DiffusionPipeline):
             # 3. optionally sample variance
             variance = 0
             if eta > 0:
-                noise = torch.randn(image.shape, generator=generator).to(image.device)
+                noise = torch.randn(image.shape, generator=generator).to(torch_device)
                 variance = (
                     self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
                 )
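Here eta is the usual DDIM stochasticity knob: eta == 0 keeps the step deterministic, while larger values add back scheduler-defined variance. A hedged sketch of that gating, where step_variance stands in for whatever self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale) returns:

import torch

def eta_scaled_noise(step_variance, eta, shape, generator, device):
    # eta == 0: no extra noise, deterministic DDIM-style step.
    # eta > 0: noise scaled by eta and the scheduler's per-step standard deviation.
    if eta <= 0:
        return torch.zeros(shape, device=device)
    noise = torch.randn(shape, generator=generator).to(device)
    return step_variance.sqrt() * eta * noise  # step_variance: stand-in tensor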
@@ -19,7 +19,7 @@ import unittest
 
 import torch
 
-from diffusers import DDIM, DDPM, PNDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
+from diffusers import DDIM, DDPM, PNDM, GLIDE, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
 from diffusers.configuration_utils import ConfigMixin
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.testing_utils import floats_tensor, slow, torch_device
@@ -212,3 +212,18 @@ class PipelineTesterMixin(unittest.TestCase):
         assert image.shape == (1, 3, 256, 256)
         expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
         assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
+
+    @slow
+    def test_glide_text2img(self):
+        model_id = "fusing/glide-base"
+        glide = GLIDE.from_pretrained(model_id)
+
+        prompt = "a pencil sketch of a corgi"
+        generator = torch.manual_seed(0)
+        image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
+
+        image_slice = image[0, :3, :3, -1].cpu()
+
+        assert image.shape == (1, 256, 256, 3)
+        expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
+        assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
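Outside the test harness, the body of the new slow test corresponds to roughly this usage; the model id, prompt, and expected output shape come from the test itself, so treat it as a sketch rather than documented API:

import torch
from diffusers import GLIDE

# Mirrors test_glide_text2img above: load the pipeline, seed the generator,
# and run text-to-image with the same upscale step count as the test.
glide = GLIDE.from_pretrained("fusing/glide-base")
generator = torch.manual_seed(0)
image = glide("a pencil sketch of a corgi", generator=generator, num_inference_steps_upscale=20)
print(image.shape)  # the test asserts this is (1, 256, 256, 3)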