patil-suraj 2022-06-14 12:51:40 +02:00
commit 542c78686f
8 changed files with 285 additions and 173 deletions

View File

@@ -3,7 +3,7 @@
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
export PYTHONPATH = src
check_dirs := tests src utils
check_dirs := examples tests src utils
modified_only_fixup:
$(eval modified_py_files := $(shell python utils/get_modified_files.py $(check_dirs)))

View File

@@ -0,0 +1,156 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
class PNDM(DiffusionPipeline):
def __init__(self, unet, noise_scheduler):
super().__init__()
noise_scheduler = noise_scheduler.set_format("pt")
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
num_trained_timesteps = self.noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device)
# Sample gaussian noise to begin loop
image = torch.randn(
(batch_size, self.unet.in_channels, self.unet.resolution, self.unet.resolution),
generator=generator,
)
image = image.to(torch_device)
seq = list(inference_step_times)
seq_next = [-1] + list(seq[:-1])
model = self.unet
warmup_steps = [len(seq) - (i // 4 + 1) for i in range(3 * 4)]
ets = []
prev_image = image
for idx, step_idx in enumerate(warmup_steps):
i = seq[step_idx]
j = seq_next[step_idx]
t = torch.ones(image.shape[0]) * i
t_next = torch.ones(image.shape[0]) * j
t_list = [t, (t + t_next) / 2, t_next]
residual = model(image.to(torch_device), t.to(torch_device))
residual = residual.to("cpu")
image = image.to("cpu")
image = self.noise_scheduler.transfer(prev_image.to("cpu"), t_list[0], t_list[1], residual)
if idx % 4 == 0:
ets.append(residual)
prev_image = image
ets = []
step_idx = len(seq) - 1
while step_idx >= 0:
i = seq[step_idx]
j = seq_next[step_idx]
t = (torch.ones(image.shape[0]) * i)
t_next = (torch.ones(image.shape[0]) * j)
residual = model(image.to(torch_device), t.to(torch_device))
residual = residual.to("cpu")
t_list = [t, (t+t_next)/2, t_next]
ets.append(residual)
if len(ets) <= 3:
image = image.to("cpu")
x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual)
e_2 = model(x_2.to(torch_device), t_list[1].to(torch_device)).to("cpu")
x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
e_3 = model(x_3.to(torch_device), t_list[1].to(torch_device)).to("cpu")
x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
e_4 = model(x_4.to(torch_device), t_list[2].to(torch_device)).to("cpu")
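# classic fourth-order Runge-Kutta combination: (1/6) * (k1 + 2*k2 + 2*k3 + k4)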
residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
else:
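# fourth-order Adams-Bashforth (linear multistep) update over the last four stored residuals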
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
image = img_next
step_idx = step_idx - 1
# if len(prev_noises) in [1, 2]:
# t = (t + t_next) / 2
# elif len(prev_noises) == 3:
# t = t_next / 2
# if len(prev_noises) == 0:
# ets.append(residual)
#
# if len(ets) > 3:
# residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
# step_idx = step_idx - 1
# elif len(ets) <= 3 and len(prev_noises) == 3:
# residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual)
# prev_noises = []
# step_idx = step_idx - 1
# elif len(ets) <= 3 and len(prev_noises) < 3:
# prev_noises.append(residual)
# if len(prev_noises) < 2:
# t_next = (t + t_next) / 2
#
# image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
return image
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper for an in-depth understanding
# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointing to x_t"
# - pred_prev_image -> "x_t-1"
# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# 1. predict noise residual
# with torch.no_grad():
# residual = self.unet(image, inference_step_times[t])
#
# 2. predict previous mean of image x_t-1
# pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
#
# 3. optionally sample variance
# variance = 0
# if eta > 0:
# noise = torch.randn(image.shape, generator=generator).to(image.device)
# variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
#
# 4. set current image to prev_image: x_t -> x_t-1
# image = pred_prev_image + variance
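
The commented notes above compress formulas (12) and (16) of the DDIM paper into pseudocode. For reference, a minimal self-contained sketch of that update, using the variable names from the notation table; this is a hand-written illustration under the assumption that alphas_cumprod holds the trained cumulative-alpha schedule, not the scheduler's actual step implementation:

import torch

def ddim_step(image, pred_noise_t, alphas_cumprod, t, t_prev, eta=0.0, generator=None):
    alpha_prod_t = alphas_cumprod[t]
    alpha_prod_t_prev = alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0)

    # formula (12): f_theta(x_t, t), the model's prediction of x_0
    pred_original_image = (image - (1 - alpha_prod_t).sqrt() * pred_noise_t) / alpha_prod_t.sqrt()

    # formula (16): sigma_t; eta = 0 is deterministic DDIM, eta = 1 recovers DDPM
    variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
    std_dev_t = eta * variance.sqrt()

    # "direction pointing to x_t"
    pred_image_direction = (1 - alpha_prod_t_prev - std_dev_t**2).sqrt() * pred_noise_t

    # x_{t-1} = sqrt(alpha_bar_{t-1}) * x_0 + direction (+ noise when eta > 0)
    pred_prev_image = alpha_prod_t_prev.sqrt() * pred_original_image + pred_image_direction
    if eta > 0:
        noise = torch.randn(image.shape, generator=generator).to(image.device)
        pred_prev_image = pred_prev_image + std_dev_t * noise
    return pred_prev_image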

View File

@@ -8,14 +8,23 @@ import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
from torchvision.transforms import CenterCrop, Compose, Lambda, RandomHorizontalFlip, Resize, ToTensor
from torchvision.transforms import (
Compose,
InterpolationMode,
Lambda,
RandomCrop,
RandomHorizontalFlip,
RandomVerticalFlip,
Resize,
ToTensor,
)
from tqdm.auto import tqdm
from transformers import get_linear_schedule_with_warmup
def set_seed(seed):
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
@@ -30,13 +39,13 @@ model = UNetModel(
attn_resolutions=(16,),
ch=128,
ch_mult=(1, 2, 2, 2),
dropout=0.1,
dropout=0.0,
num_res_blocks=2,
resamp_with_conv=True,
resolution=32
resolution=32,
)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
num_epochs = 100
batch_size = 64
@@ -44,14 +53,15 @@ gradient_accumulation_steps = 2
augmentations = Compose(
[
Resize(32),
CenterCrop(32),
Resize(32, interpolation=InterpolationMode.BILINEAR),
RandomHorizontalFlip(),
RandomVerticalFlip(),
RandomCrop(32),
ToTensor(),
Lambda(lambda x: x * 2 - 1),
]
)
dataset = load_dataset("huggan/pokemon", split="train")
dataset = load_dataset("huggan/flowers-102-categories", split="train")
def transforms(examples):
@@ -59,24 +69,24 @@ def transforms(examples):
return {"input": images}
dataset = dataset.shuffle(seed=0)
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
#lr_scheduler = get_linear_schedule_with_warmup(
# optimizer=optimizer,
# num_warmup_steps=1000,
# num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
#)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=500,
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
)
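
As a note on the schedule enabled above: get_linear_schedule_with_warmup scales the base learning rate by a factor that ramps linearly from 0 to 1 over num_warmup_steps and then decays linearly back to 0. A small sketch of that factor, illustrating the helper's documented behavior (the num_training_steps value here is hypothetical):

def lr_factor(step, num_warmup_steps=500, num_training_steps=10_000):
    # linear warmup from 0 to 1, then linear decay from 1 to 0
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))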
model, optimizer, train_dataloader = accelerator.prepare(
model, optimizer, train_dataloader
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
for epoch in range(num_epochs):
model.train()
pbar = tqdm(total=len(train_dataloader), unit="ba")
pbar.set_description(f"Epoch {epoch}")
losses = []
for step, batch in enumerate(train_dataloader):
clean_images = batch["input"]
noisy_images = torch.empty_like(clean_images)
@@ -101,10 +111,12 @@ for epoch in range(num_epochs):
accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
# lr_scheduler.step()
lr_scheduler.step()
optimizer.zero_grad()
loss = loss.detach().item()
losses.append(loss)
pbar.update(1)
pbar.set_postfix(loss=loss.detach().item(), lr=optimizer.param_groups[0]["lr"])
pbar.set_postfix(loss=loss, avg_loss=np.mean(losses), lr=optimizer.param_groups[0]["lr"])
optimizer.step()
@@ -124,5 +136,5 @@ for epoch in range(num_epochs):
image_pil = PIL.Image.fromarray(image_processed[0])
# save image
pipeline.save_pretrained("./poke-ddpm")
image_pil.save(f"./poke-ddpm/test_{epoch}.png")
pipeline.save_pretrained("./flowers-ddpm")
image_pil.save(f"./flowers-ddpm/test_{epoch}.png")

View File

@@ -225,11 +225,8 @@ class ConfigMixin:
text = reader.read()
return json.loads(text)
# def __eq__(self, other):
# return self.__dict__ == other.__dict__
# def __repr__(self):
# return f"{self.__class__.__name__} {self.to_json_string()}"
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
@property
def config(self) -> Dict[str, Any]:

View File

@@ -832,12 +832,12 @@ class GLIDE(DiffusionPipeline):
# 1. Sample gaussian noise
batch_size = 2 # second image is empty for classifier-free guidance
image = self.text_noise_scheduler.sample_noise(
(batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
)
image = torch.randn(
(batch_size, self.text_unet.in_channels, 64, 64), generator=generator
).to(torch_device)
# 2. Encode tokens
# an empty input is needed to guide the model away from (
# an empty input is needed to guide the model away from it
inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
input_ids = inputs["input_ids"].to(torch_device)
attention_mask = inputs["attention_mask"].to(torch_device)
@@ -850,7 +850,7 @@ class GLIDE(DiffusionPipeline):
mean, variance, log_variance, pred_xstart = self.p_mean_variance(
text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
)
noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
noise = torch.randn(image.shape, generator=generator).to(torch_device)
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
@@ -873,8 +873,8 @@ class GLIDE(DiffusionPipeline):
self.upscale_unet.resolution,
),
generator=generator,
)
image = image.to(torch_device) * upsample_temp
).to(torch_device)
image = image * upsample_temp
num_trained_timesteps = self.upscale_noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
@@ -896,7 +896,7 @@ class GLIDE(DiffusionPipeline):
# 3. optionally sample variance
variance = 0
if eta > 0:
noise = torch.randn(image.shape, generator=generator).to(image.device)
noise = torch.randn(image.shape, generator=generator).to(torch_device)
variance = (
self.upscale_noise_scheduler.get_variance(t, num_inference_steps_upscale).sqrt() * eta * noise
)

View File

@@ -28,13 +28,11 @@ class PNDM(DiffusionPipeline):
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
def __call__(self, batch_size=1, generator=None, torch_device=None, num_inference_steps=50):
# eta corresponds to η in paper and should be between [0, 1]
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
if torch_device is None:
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
num_trained_timesteps = self.noise_scheduler.timesteps
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps)
self.unet.to(torch_device)
# Sample gaussian noise to begin loop
@@ -44,91 +42,18 @@
)
image = image.to(torch_device)
seq = list(inference_step_times)
seq_next = [-1] + list(seq[:-1])
model = self.unet
warmup_time_steps = self.noise_scheduler.get_warmup_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(warmup_time_steps))):
t_orig = warmup_time_steps[t]
residual = self.unet(image, t_orig)
ets = []
prev_noises = []
step_idx = len(seq) - 1
while step_idx >= 0:
i = seq[step_idx]
j = seq_next[step_idx]
image = self.noise_scheduler.step_prk(residual, image, t, num_inference_steps)
t = (torch.ones(image.shape[0]) * i)
t_next = (torch.ones(image.shape[0]) * j)
timesteps = self.noise_scheduler.get_time_steps(num_inference_steps)
for t in tqdm.tqdm(range(len(timesteps))):
t_orig = timesteps[t]
residual = self.unet(image, t_orig)
residual = model(image.to("cuda"), t.to("cuda"))
residual = residual.to("cpu")
t_list = [t, (t+t_next)/2, t_next]
ets.append(residual)
if len(ets) <= 3:
image = image.to("cpu")
x_2 = self.noise_scheduler.transfer(image.to("cpu"), t_list[0], t_list[1], residual)
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda")).to("cpu")
x_3 = self.noise_scheduler.transfer(image, t_list[0], t_list[1], e_2)
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda")).to("cpu")
x_4 = self.noise_scheduler.transfer(image, t_list[0], t_list[2], e_3)
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda")).to("cpu")
residual = (1 / 6) * (residual + 2 * e_2 + 2 * e_3 + e_4)
else:
residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
img_next = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
image = img_next
step_idx = step_idx - 1
# if len(prev_noises) in [1, 2]:
# t = (t + t_next) / 2
# elif len(prev_noises) == 3:
# t = t_next / 2
# if len(prev_noises) == 0:
# ets.append(residual)
#
# if len(ets) > 3:
# residual = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
# step_idx = step_idx - 1
# elif len(ets) <= 3 and len(prev_noises) == 3:
# residual = (1 / 6) * (prev_noises[-3] + 2 * prev_noises[-2] + 2 * prev_noises[-1] + residual)
# prev_noises = []
# step_idx = step_idx - 1
# elif len(ets) <= 3 and len(prev_noises) < 3:
# prev_noises.append(residual)
# if len(prev_noises) < 2:
# t_next = (t + t_next) / 2
#
# image = self.noise_scheduler.transfer(image.to("cpu"), t, t_next, residual)
image = self.noise_scheduler.step_plms(residual, image, t, num_inference_steps)
return image
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper for an in-depth understanding
# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_image -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_image_direction -> "direction pointing to x_t"
# - pred_prev_image -> "x_t-1"
# for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
# 1. predict noise residual
# with torch.no_grad():
# residual = self.unet(image, inference_step_times[t])
#
# 2. predict previous mean of image x_t-1
# pred_prev_image = self.noise_scheduler.step(residual, image, t, num_inference_steps, eta)
#
# 3. optionally sample variance
# variance = 0
# if eta > 0:
# noise = torch.randn(image.shape, generator=generator).to(image.device)
# variance = self.noise_scheduler.get_variance(t, num_inference_steps).sqrt() * eta * noise
#
# 4. set current image to prev_image: x_t -> x_t-1
# image = pred_prev_image + variance

View File

@@ -55,22 +55,17 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=tensor_format)
# self.register_buffer("betas", betas.to(torch.float32))
# self.register_buffer("alphas", alphas.to(torch.float32))
# self.register_buffer("alphas_cumprod", alphas_cumprod.to(torch.float32))
# For now we only support F-PNDM, i.e. the Runge-Kutta method
# For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
# mainly equations (12) and (13) and Algorithm 2.
self.pndm_order = 4
# alphas_cumprod_prev = torch.nn.functional.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# TODO(PVP) - check how much of these is actually necessary!
# LDM only uses "fixed_small"; glide seems to use a weird mix of the two, ...
# https://github.com/openai/glide-text2im/blob/69b530740eb6cef69442d6180579ef5ba9ef063e/glide_text2im/gaussian_diffusion.py#L246
# variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# if variance_type == "fixed_small":
# log_variance = torch.log(variance.clamp(min=1e-20))
# elif variance_type == "fixed_large":
# log_variance = torch.log(torch.cat([variance[1:2], betas[1:]], dim=0))
#
#
# self.register_buffer("log_variance", log_variance.to(torch.float32))
# running values
self.cur_residual = 0
self.cur_image = None
self.ets = []
self.warmup_time_steps = {}
self.time_steps = {}
def get_alpha(self, time_step):
return self.alphas[time_step]
@@ -83,51 +78,64 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return self.one
return self.alphas_cumprod[time_step]
def step(self, img, t_start, t_end, model, ets):
# img_next = self.method(img_n, t_start, t_end, model, self.alphas_cump, self.ets)
#def gen_order_4(img, t, t_next, model, alphas_cump, ets):
t_next, t = t_start, t_end
def get_warmup_time_steps(self, num_inference_steps):
if num_inference_steps in self.warmup_time_steps:
return self.warmup_time_steps[num_inference_steps]
noise_ = model(img.to("cuda"), t.to("cuda"))
noise_ = noise_.to("cpu")
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
t_list = [t, (t+t_next)/2, t_next]
if len(ets) > 2:
ets.append(noise_)
noise = (1 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])
else:
noise = self.runge_kutta(img, t_list, model, ets, noise_)
warmup_time_steps = np.array(inference_step_times[-self.pndm_order:]).repeat(2) + np.tile(np.array([0, self.timesteps // num_inference_steps // 2]), self.pndm_order)
self.warmup_time_steps[num_inference_steps] = list(reversed(warmup_time_steps[:-1].repeat(2)[1:-1]))
img_next = self.transfer(img.to("cpu"), t, t_next, noise)
return img_next, ets
return self.warmup_time_steps[num_inference_steps]
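
To make the array gymnastics in get_warmup_time_steps concrete, a hand-computed trace (assuming timesteps=1000, num_inference_steps=50, pndm_order=4):

# inference_step_times[-4:]            -> [920, 940, 960, 980]
# after repeat(2) + tiled offset of 10 -> [920, 930, 940, 950, 960, 970, 980, 990]
# drop the last, repeat(2), trim the ends, reverse:
# [980, 970, 970, 960, 960, 950, 950, 940, 940, 930, 930, 920]
# i.e. three Runge-Kutta warmup steps, each evaluating at t, t - 10, t - 10, t - 20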
def runge_kutta(self, x, t_list, model, ets, noise_):
model = model.to("cuda")
x = x.to("cpu")
def get_time_steps(self, num_inference_steps):
if num_inference_steps in self.time_steps:
return self.time_steps[num_inference_steps]
e_1 = noise_
ets.append(e_1)
x_2 = self.transfer(x, t_list[0], t_list[1], e_1)
inference_step_times = list(range(0, self.timesteps, self.timesteps // num_inference_steps))
self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
e_2 = model(x_2.to("cuda"), t_list[1].to("cuda"))
e_2 = e_2.to("cpu")
x_3 = self.transfer(x, t_list[0], t_list[1], e_2)
return self.time_steps[num_inference_steps]
e_3 = model(x_3.to("cuda"), t_list[1].to("cuda"))
e_3 = e_3.to("cpu")
x_4 = self.transfer(x, t_list[0], t_list[2], e_3)
def step_prk(self, residual, image, t, num_inference_steps):
# TODO(Patrick) - need to rethink whether the "warmup" way is the correct API design here
warmup_time_steps = self.get_warmup_time_steps(num_inference_steps)
e_4 = model(x_4.to("cuda"), t_list[2].to("cuda"))
e_4 = e_4.to("cpu")
t_prev = warmup_time_steps[t // 4 * 4]
t_next = warmup_time_steps[min(t + 1, len(warmup_time_steps) - 1)]
et = (1 / 6) * (e_1 + 2 * e_2 + 2 * e_3 + e_4)
if t % 4 == 0:
self.cur_residual += 1 / 6 * residual
self.ets.append(residual)
self.cur_image = image
elif (t - 1) % 4 == 0:
self.cur_residual += 1 / 3 * residual
elif (t - 2) % 4 == 0:
self.cur_residual += 1 / 3 * residual
elif (t - 3) % 4 == 0:
residual = self.cur_residual + 1 / 6 * residual
self.cur_residual = 0
return et
return self.transfer(self.cur_image, t_prev, t_next, residual)
def step_plms(self, residual, image, t, num_inference_steps):
timesteps = self.get_time_steps(num_inference_steps)
t_prev = timesteps[t]
t_next = timesteps[min(t + 1, len(timesteps) - 1)]
self.ets.append(residual)
residual = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
return self.transfer(image, t_prev, t_next, residual)
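
The 55/-59/37/-9 weights in step_plms are the standard fourth-order Adams-Bashforth coefficients; as a quick sanity check of consistency, they sum to 1:

coeffs = [55 / 24, -59 / 24, 37 / 24, -9 / 24]
assert abs(sum(coeffs) - 1.0) < 1e-12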
def transfer(self, x, t, t_next, et):
alphas_cump = self.alphas_cumprod
at = alphas_cump[t.long() + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next.long() + 1].view(-1, 1, 1, 1)
# TODO(Patrick): clean up to be compatible with numpy and give better names
alphas_cump = self.alphas_cumprod.to(x.device)
at = alphas_cump[t + 1].view(-1, 1, 1, 1)
at_next = alphas_cump[t_next + 1].view(-1, 1, 1, 1)
x_delta = (at_next - at) * ((1 / (at.sqrt() * (at.sqrt() + at_next.sqrt()))) * x - 1 / (at.sqrt() * (((1 - at_next) * at).sqrt() + ((1 - at) * at_next).sqrt())) * et)
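
Written out, the x_delta line above corresponds to the transfer step of the PNDM paper, with $\bar{\alpha}_t$ denoting alphas_cump[t + 1] and $\epsilon$ the predicted residual:

$$x_\Delta = (\bar{\alpha}_{t'} - \bar{\alpha}_t) \left( \frac{x_t}{\sqrt{\bar{\alpha}_t}\,(\sqrt{\bar{\alpha}_t} + \sqrt{\bar{\alpha}_{t'}})} - \frac{\epsilon}{\sqrt{\bar{\alpha}_t}\,\big(\sqrt{(1 - \bar{\alpha}_{t'})\,\bar{\alpha}_t} + \sqrt{(1 - \bar{\alpha}_t)\,\bar{\alpha}_{t'}}\big)} \right)$$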

View File

@@ -19,7 +19,7 @@ import unittest
import torch
from diffusers import DDIM, DDPM, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
from diffusers import DDIM, DDPM, PNDM, GLIDE, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, PNDMScheduler, UNetModel
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave
@@ -229,3 +229,17 @@ class PipelineTesterMixin(unittest.TestCase):
_ = BDDM.from_pretrained(tmpdirname)
# check if the same works using the DiffusionPipeline class
_ = DiffusionPipeline.from_pretrained(tmpdirname)
@slow
def test_glide_text2img(self):
model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0)
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
image_slice = image[0, :3, :3, -1].cpu()
assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
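
Assuming the @slow decorator here follows the transformers testing convention it originates from, tests like this are skipped by default and run only with RUN_SLOW=1, e.g. RUN_SLOW=1 python -m pytest tests/ -k test_glide_text2img.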