refactor GLIDE text2im pipeline, remove classifier_free_guidance

anton-l 2022-06-21 14:07:58 +02:00
parent 072d75196c
commit 9e31c6a749
8 changed files with 54 additions and 208 deletions

View File

@@ -10,6 +10,7 @@ from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.modeling_utils import unwrap_model
+from diffusers.optimization import get_scheduler
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
@@ -21,7 +22,6 @@ from torchvision.transforms import (
     ToTensor,
 )
 from tqdm.auto import tqdm
-from diffusers.optimization import get_scheduler


 logger = logging.get_logger(__name__)

View File

@@ -13,7 +13,6 @@ from .models.unet_rl import TemporalUNet
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import BDDM, DDIM, DDPM, PNDM
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
-from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler


 if is_transformers_available():

View File

@@ -36,7 +36,6 @@ LOADABLE_CLASSES = {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
         "SchedulerMixin": ["save_config", "from_config"],
         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
-        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],

View File

@@ -32,7 +32,7 @@ from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward
 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
-from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
+from ..schedulers import DDIMScheduler, DDPMScheduler
 from ..utils import logging

@@ -715,7 +715,7 @@ class GLIDE(DiffusionPipeline):
     def __init__(
         self,
         text_unet: GLIDETextToImageUNetModel,
-        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_noise_scheduler: DDPMScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
@@ -731,100 +731,28 @@ class GLIDE(DiffusionPipeline):
             upscale_noise_scheduler=upscale_noise_scheduler,
         )

-    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
-        """
-        Compute the mean and variance of the diffusion posterior:
-            q(x_{t-1} | x_t, x_0)
-        """
-        assert x_start.shape == x_t.shape
-        posterior_mean = (
-            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
-        assert (
-            posterior_mean.shape[0]
-            == posterior_variance.shape[0]
-            == posterior_log_variance_clipped.shape[0]
-            == x_start.shape[0]
-        )
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
-        """
-        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of the initial x, x_0.
-
-        :param model: the model, which takes a signal and a batch of timesteps as input.
-        :param x: the [N x C x ...] tensor at time t.
-        :param t: a 1-D Tensor of timesteps.
-        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
-        :param model_kwargs: if not None, a dict of extra keyword arguments to pass to the model. This can be used
-            for conditioning.
-        :return: a dict with the following keys:
-                 - 'mean': the model mean output.
-                 - 'variance': the model variance output.
-                 - 'log_variance': the log of 'variance'.
-                 - 'pred_xstart': the prediction for x_0.
-        """
-        B, C = x.shape[:2]
-        assert t.shape == (B,)
-        if transformer_out is None:
-            # super-res model
-            model_output = model(x, t, low_res)
-        else:
-            # text2image model
-            model_output = model(x, t, transformer_out)
-
-        assert model_output.shape == (B, C * 2, *x.shape[2:])
-        model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
-        # The model_var_values is [-1, 1] for [min_var, max_var].
-        frac = (model_var_values + 1) / 2
-        model_log_variance = frac * max_log + (1 - frac) * min_log
-        model_variance = torch.exp(model_log_variance)
-
-        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
-        if clip_denoised:
-            pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
-
-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
-        return model_mean, model_variance, model_log_variance, pred_xstart
-
-    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
-        assert x_t.shape == eps.shape
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
-        )
-
-    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
-        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
     @torch.no_grad()
-    def __call__(self, prompt, generator=None, torch_device=None, num_inference_steps_upscale=50):
+    def __call__(
+        self,
+        prompt,
+        generator=None,
+        torch_device=None,
+        num_inference_steps_upscale=50,
+        guidance_scale=3.0,
+        eta=0.0,
+        upsample_temp=0.997,
+    ):
         torch_device = "cuda" if torch.cuda.is_available() else "cpu"

         self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
         self.upscale_unet.to(torch_device)

-        # Create a classifier-free guidance sampling function
-        guidance_scale = 3.0
-
-        def text_model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
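The function above is the whole of classifier-free guidance: the batch carries a conditional and an unconditional half, one U-Net pass predicts noise for both, and the guided prediction is extrapolated away from the unconditional one. A self-contained sketch of the same arithmetic (shapes and scale are illustrative):

import torch

# Classifier-free guidance on a noise prediction, mirroring text_model_fn:
# eps = uncond + s * (cond - uncond); s = 1 recovers the conditional model,
# larger s pushes samples harder toward the prompt.
guidance_scale = 3.0
cond_eps = torch.randn(1, 3, 64, 64)    # predicted with the text prompt
uncond_eps = torch.randn(1, 3, 64, 64)  # predicted with the empty prompt
guided_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)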
@@ -833,7 +761,15 @@ class GLIDE(DiffusionPipeline):

         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)
+        image = torch.randn(
+            (
+                batch_size,
+                self.text_unet.in_channels,
+                self.text_unet.resolution,
+                self.text_unet.resolution,
+            ),
+            generator=generator,
+        ).to(torch_device)

         # 2. Encode tokens
         # an empty input is needed to guide the model away from it
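That comment is the other half of the guidance setup: the real prompt and an empty prompt are encoded as one batch of two, so a single text-encoder and U-Net pass yields both the conditional and unconditional predictions. A hedged sketch of how such a token batch can be built (the padding settings are illustrative, not the pipeline's exact call):

# Illustrative only: encode the prompt and an empty string together so the
# first batch element is conditional and the second is unconditional.
prompt = "an oil painting of a corgi"
inputs = tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
input_ids, attention_mask = inputs.input_ids, inputs.attention_mask  # batch of 2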
@@ -843,25 +779,30 @@ class GLIDE(DiffusionPipeline):
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

         # 3. Run the text2image generation step
-        num_timesteps = len(self.text_noise_scheduler)
-        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
-            t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
-                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
-            )
+        num_prediction_steps = len(self.text_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            with torch.no_grad():
+                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                model_output = text_model_fn(image, time_input, transformer_out)
+                noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
+
+                min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
+                max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
+                # The model_var_values is [-1, 1] for [min_var, max_var].
+                frac = (model_var_values + 1) / 2
+                model_log_variance = frac * max_log + (1 - frac) * min_log
+
+            pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t)
             noise = torch.randn(image.shape, generator=generator).to(torch_device)
-            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
-            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
+            variance = torch.exp(0.5 * model_log_variance) * noise
+
+            # set current image to prev_image: x_t -> x_t-1
+            image = pred_prev_image + variance

         # 4. Run the upscaling step
         batch_size = 1
         image = image[:1]
         low_res = ((image + 1) * 127.5).round() / 127.5 - 1
-        eta = 0.0
-
-        # Tune this parameter to control the sharpness of 256x256 images.
-        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
-        upsample_temp = 0.997

         # Sample gaussian noise to begin loop
         image = torch.randn(
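The new loop body inlines GLIDE's learned-variance trick: the model's extra output channels, valued in [-1, 1], pick a log-variance between the scheduler's "fixed_small_log" and "fixed_large_log" bounds. A standalone numeric sketch (all values illustrative):

import torch

# Learned-variance interpolation, as in the loop above.
min_log = torch.tensor(-11.5)         # e.g. log of the clipped posterior variance
max_log = torch.tensor(-6.9)          # e.g. log(beta_t)
model_var_values = torch.tensor(0.5)  # network output for one position

frac = (model_var_values + 1) / 2                           # map [-1, 1] to [0, 1]
model_log_variance = frac * max_log + (1 - frac) * min_log  # blend the two bounds
std = torch.exp(0.5 * model_log_variance)                   # stddev used when adding noise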
@@ -877,8 +818,6 @@ class GLIDE(DiffusionPipeline):
         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)

-        # adapt the beta schedule to the number of steps
-        # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
-
         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
             # 1. predict noise residual
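For reference, inference_step_times subsamples the trained schedule at an even stride, so the DDIM upscaling loop visits only num_inference_steps_upscale of the trained timesteps. A quick worked example:

# Even timestep striding, e.g. 1000 trained steps visited in 50 inference steps.
num_trained_timesteps = 1000
num_inference_steps_upscale = 50
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
print(list(inference_step_times)[:5])  # [0, 20, 40, 60, 80]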

View File

@@ -16,7 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
 from .scheduling_grad_tts import GradTTSScheduler

View File

@@ -1,96 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import math
-
-import numpy as np
-import torch
-from torch import nn
-
-from ..configuration_utils import ConfigMixin
-
-
-SAMPLING_CONFIG_NAME = "scheduler_config.json"
-
-
-def linear_beta_schedule(timesteps, beta_start, beta_end):
-    return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
-
-
-def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
-    """
-    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product
-    of (1-beta) over time from t = [0,1].
-
-    :param num_diffusion_timesteps: the number of betas to produce.
-    :param alpha_bar: a lambda that takes an argument t from 0 to 1 and produces the cumulative product of
-        (1-beta) up to that part of the diffusion process.
-    :param max_beta: the maximum beta to use; use values lower than 1 to prevent singularities.
-    """
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
-    return np.array(betas, dtype=np.float64)
-
-
-class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
-
-    config_name = SAMPLING_CONFIG_NAME
-
-    def __init__(
-        self,
-        timesteps=1000,
-        beta_schedule="squaredcos_cap_v2",
-    ):
-        super().__init__()
-        self.register_to_config(
-            timesteps=timesteps,
-            beta_schedule=beta_schedule,
-        )
-
-        if beta_schedule == "squaredcos_cap_v2":
-            # GLIDE cosine schedule
-            self.betas = betas_for_alpha_bar(
-                timesteps,
-                lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
-            )
-        else:
-            raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
-
-        alphas = 1.0 - self.betas
-        self.alphas_cumprod = np.cumprod(alphas, axis=0)
-        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
-
-        # calculations for diffusion q(x_t | x_{t-1}) and others
-        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
-        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
-
-        # calculations for posterior q(x_{t-1} | x_t, x_0)
-        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-        self.posterior_log_variance_clipped = np.log(
-            np.append(self.posterior_variance[1], self.posterior_variance[1:])
-        )
-        self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
-        self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
-
-    def sample_noise(self, shape, device, generator=None):
-        # always sample on CPU to be deterministic
-        return torch.randn(shape, generator=generator).to(device)
-
-    def __len__(self):
-        return self.config.timesteps
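The only substantive logic in the deleted file, the betas_for_alpha_bar discretization of GLIDE's cosine schedule, lives on in the DDPM scheduler. For reference, a worked example of what it computes (printed values are approximate):

import math
import numpy as np

# beta_i = 1 - alpha_bar(t_{i+1}) / alpha_bar(t_i), capped at max_beta = 0.999.
def alpha_bar(t):
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

T = 1000
betas = np.array(
    [min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)]
)
print(betas[0])   # ~4e-5: almost no noise at the start of the chain
print(betas[-1])  # 0.999: capped, since alpha_bar(1.0) = 0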

View File

@@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.set_format(tensor_format=tensor_format)

-    def get_variance(self, t):
+    def get_variance(self, t, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

@@ -96,14 +96,20 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

+        if variance_type is None:
+            variance_type = self.config.variance_type
+
         # hacks - were probs added for training stability
-        if self.config.variance_type == "fixed_small":
+        if variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
         # for rl-diffuser https://arxiv.org/abs/2205.09991
-        elif self.config.variance_type == "fixed_small_log":
+        elif variance_type == "fixed_small_log":
             variance = self.log(self.clip(variance, min_value=1e-20))
-        elif self.config.variance_type == "fixed_large":
+        elif variance_type == "fixed_large":
             variance = self.betas[t]
+        elif variance_type == "fixed_large_log":
+            # GLIDE max_log
+            variance = self.log(self.betas[t])

         return variance
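With the new variance_type override, one configured scheduler can report several variance parameterizations per call, which is how the GLIDE pipeline above obtains its min_log and max_log bounds. A hedged usage sketch (constructor arguments are illustrative):

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(timesteps=1000)  # illustrative constructor args

t = 500
min_log = scheduler.get_variance(t, "fixed_small_log")  # log of clipped posterior variance
max_log = scheduler.get_variance(t, "fixed_large_log")  # log(beta_t), GLIDE's max_log
default = scheduler.get_variance(t)  # falls back to scheduler.config.variance_type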