refactor GLIDE text2im pipeline, remove classifier_free_guidance
This commit is contained in:
parent
072d75196c
commit
9e31c6a749
|
@ -10,6 +10,7 @@ from datasets import load_dataset
|
|||
from diffusers import DDPM, DDPMScheduler, UNetModel
|
||||
from diffusers.hub_utils import init_git_repo, push_to_hub
|
||||
from diffusers.modeling_utils import unwrap_model
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import logging
|
||||
from torchvision.transforms import (
|
||||
CenterCrop,
|
||||
|
@ -21,7 +22,6 @@ from torchvision.transforms import (
|
|||
ToTensor,
|
||||
)
|
||||
from tqdm.auto import tqdm
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
|
|
@ -13,7 +13,6 @@ from .models.unet_rl import TemporalUNet
|
|||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import BDDM, DDIM, DDPM, PNDM
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
|
||||
from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
|
|
|
@ -36,7 +36,6 @@ LOADABLE_CLASSES = {
|
|||
"ModelMixin": ["save_pretrained", "from_pretrained"],
|
||||
"SchedulerMixin": ["save_config", "from_config"],
|
||||
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
|
||||
"ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
|
||||
},
|
||||
"transformers": {
|
||||
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
|
||||
|
|
|
@ -32,7 +32,7 @@ from transformers.utils import ModelOutput, add_start_docstrings_to_model_forwar
|
|||
|
||||
from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
|
||||
from ..schedulers import DDIMScheduler, DDPMScheduler
|
||||
from ..utils import logging
|
||||
|
||||
|
||||
|
@ -715,7 +715,7 @@ class GLIDE(DiffusionPipeline):
|
|||
def __init__(
|
||||
self,
|
||||
text_unet: GLIDETextToImageUNetModel,
|
||||
text_noise_scheduler: ClassifierFreeGuidanceScheduler,
|
||||
text_noise_scheduler: DDPMScheduler,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: GPT2Tokenizer,
|
||||
upscale_unet: GLIDESuperResUNetModel,
|
||||
|
@ -731,100 +731,28 @@ class GLIDE(DiffusionPipeline):
|
|||
upscale_noise_scheduler=upscale_noise_scheduler,
|
||||
)
|
||||
|
||||
def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
|
||||
"""
|
||||
Compute the mean and variance of the diffusion posterior:
|
||||
|
||||
q(x_{t-1} | x_t, x_0)
|
||||
|
||||
"""
|
||||
assert x_start.shape == x_t.shape
|
||||
posterior_mean = (
|
||||
_extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
|
||||
+ _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
|
||||
assert (
|
||||
posterior_mean.shape[0]
|
||||
== posterior_variance.shape[0]
|
||||
== posterior_log_variance_clipped.shape[0]
|
||||
== x_start.shape[0]
|
||||
)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
|
||||
"""
|
||||
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
||||
the initial x, x_0.
|
||||
|
||||
:param model: the model, which takes a signal and a batch of timesteps
|
||||
as input.
|
||||
:param x: the [N x C x ...] tensor at time t.
|
||||
:param t: a 1-D Tensor of timesteps.
|
||||
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
||||
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
||||
pass to the model. This can be used for conditioning.
|
||||
:return: a dict with the following keys:
|
||||
- 'mean': the model mean output.
|
||||
- 'variance': the model variance output.
|
||||
- 'log_variance': the log of 'variance'.
|
||||
- 'pred_xstart': the prediction for x_0.
|
||||
"""
|
||||
|
||||
B, C = x.shape[:2]
|
||||
assert t.shape == (B,)
|
||||
if transformer_out is None:
|
||||
# super-res model
|
||||
model_output = model(x, t, low_res)
|
||||
else:
|
||||
# text2image model
|
||||
model_output = model(x, t, transformer_out)
|
||||
|
||||
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
||||
model_output, model_var_values = torch.split(model_output, C, dim=1)
|
||||
min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
|
||||
max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
|
||||
# The model_var_values is [-1, 1] for [min_var, max_var].
|
||||
frac = (model_var_values + 1) / 2
|
||||
model_log_variance = frac * max_log + (1 - frac) * min_log
|
||||
model_variance = torch.exp(model_log_variance)
|
||||
|
||||
pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
|
||||
if clip_denoised:
|
||||
pred_xstart = pred_xstart.clamp(-1, 1)
|
||||
model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
|
||||
|
||||
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
||||
return model_mean, model_variance, model_log_variance, pred_xstart
|
||||
|
||||
def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
|
||||
assert x_t.shape == eps.shape
|
||||
return (
|
||||
_extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
||||
)
|
||||
|
||||
def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
|
||||
return (
|
||||
_extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
||||
) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, prompt, generator=None, torch_device=None, num_inference_steps_upscale=50):
|
||||
def __call__(
|
||||
self,
|
||||
prompt,
|
||||
generator=None,
|
||||
torch_device=None,
|
||||
num_inference_steps_upscale=50,
|
||||
guidance_scale=3.0,
|
||||
eta=0.0,
|
||||
upsample_temp=0.997,
|
||||
):
|
||||
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.text_unet.to(torch_device)
|
||||
self.text_encoder.to(torch_device)
|
||||
self.upscale_unet.to(torch_device)
|
||||
|
||||
# Create a classifier-free guidance sampling function
|
||||
guidance_scale = 3.0
|
||||
|
||||
def text_model_fn(x_t, ts, transformer_out, **kwargs):
|
||||
def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
|
||||
half = x_t[: len(x_t) // 2]
|
||||
combined = torch.cat([half, half], dim=0)
|
||||
model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
|
||||
model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
|
||||
eps, rest = model_out[:, :3], model_out[:, 3:]
|
||||
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
|
||||
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
|
||||
|
@ -833,7 +761,15 @@ class GLIDE(DiffusionPipeline):
|
|||
|
||||
# 1. Sample gaussian noise
|
||||
batch_size = 2 # second image is empty for classifier-free guidance
|
||||
image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)
|
||||
image = torch.randn(
|
||||
(
|
||||
batch_size,
|
||||
self.text_unet.in_channels,
|
||||
self.text_unet.resolution,
|
||||
self.text_unet.resolution,
|
||||
),
|
||||
generator=generator,
|
||||
).to(torch_device)
|
||||
|
||||
# 2. Encode tokens
|
||||
# an empty input is needed to guide the model away from it
|
||||
|
@ -843,25 +779,30 @@ class GLIDE(DiffusionPipeline):
|
|||
transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state
|
||||
|
||||
# 3. Run the text2image generation step
|
||||
num_timesteps = len(self.text_noise_scheduler)
|
||||
for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
|
||||
t = torch.tensor([i] * image.shape[0], device=torch_device)
|
||||
mean, variance, log_variance, pred_xstart = self.p_mean_variance(
|
||||
text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
|
||||
)
|
||||
num_prediction_steps = len(self.text_noise_scheduler)
|
||||
for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
|
||||
with torch.no_grad():
|
||||
time_input = torch.tensor([t] * image.shape[0], device=torch_device)
|
||||
model_output = text_model_fn(image, time_input, transformer_out)
|
||||
noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
|
||||
|
||||
min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
|
||||
max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
|
||||
# The model_var_values is [-1, 1] for [min_var, max_var].
|
||||
frac = (model_var_values + 1) / 2
|
||||
model_log_variance = frac * max_log + (1 - frac) * min_log
|
||||
|
||||
pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t)
|
||||
noise = torch.randn(image.shape, generator=generator).to(torch_device)
|
||||
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1))) # no noise when t == 0
|
||||
image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
|
||||
variance = torch.exp(0.5 * model_log_variance) * noise
|
||||
|
||||
# set current image to prev_image: x_t -> x_t-1
|
||||
image = pred_prev_image + variance
|
||||
|
||||
# 4. Run the upscaling step
|
||||
batch_size = 1
|
||||
image = image[:1]
|
||||
low_res = ((image + 1) * 127.5).round() / 127.5 - 1
|
||||
eta = 0.0
|
||||
|
||||
# Tune this parameter to control the sharpness of 256x256 images.
|
||||
# A value of 1.0 is sharper, but sometimes results in grainy artifacts.
|
||||
upsample_temp = 0.997
|
||||
|
||||
# Sample gaussian noise to begin loop
|
||||
image = torch.randn(
|
||||
|
@ -877,8 +818,6 @@ class GLIDE(DiffusionPipeline):
|
|||
|
||||
num_trained_timesteps = self.upscale_noise_scheduler.timesteps
|
||||
inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
|
||||
# adapt the beta schedule to the number of steps
|
||||
# self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)
|
||||
|
||||
for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
|
||||
# 1. predict noise residual
|
||||
|
|
|
@ -16,7 +16,6 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
from .scheduling_ddpm import DDPMScheduler
|
||||
from .scheduling_grad_tts import GradTTSScheduler
|
||||
|
|
|
@ -1,96 +0,0 @@
|
|||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
|
||||
|
||||
SAMPLING_CONFIG_NAME = "scheduler_config.json"
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps, beta_start, beta_end):
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas, dtype=np.float64)
|
||||
|
||||
|
||||
class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
|
||||
|
||||
config_name = SAMPLING_CONFIG_NAME
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timesteps=1000,
|
||||
beta_schedule="squaredcos_cap_v2",
|
||||
):
|
||||
super().__init__()
|
||||
self.register_to_config(
|
||||
timesteps=timesteps,
|
||||
beta_schedule=beta_schedule,
|
||||
)
|
||||
|
||||
if beta_schedule == "squaredcos_cap_v2":
|
||||
# GLIDE cosine schedule
|
||||
self.betas = betas_for_alpha_bar(
|
||||
timesteps,
|
||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
||||
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
self.posterior_log_variance_clipped = np.log(
|
||||
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
||||
)
|
||||
self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
||||
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
||||
|
||||
def sample_noise(self, shape, device, generator=None):
|
||||
# always sample on CPU to be deterministic
|
||||
return torch.randn(shape, generator=generator).to(device)
|
||||
|
||||
def __len__(self):
|
||||
return self.config.timesteps
|
|
@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def get_variance(self, t):
|
||||
def get_variance(self, t, variance_type=None):
|
||||
alpha_prod_t = self.alphas_cumprod[t]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one
|
||||
|
||||
|
@ -96,14 +96,20 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
# x_{t-1} ~ N(pred_prev_sample, variance) == add variane to pred_sample
|
||||
variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]
|
||||
|
||||
if variance_type is None:
|
||||
variance_type = self.config.variance_type
|
||||
|
||||
# hacks - were probs added for training stability
|
||||
if self.config.variance_type == "fixed_small":
|
||||
if variance_type == "fixed_small":
|
||||
variance = self.clip(variance, min_value=1e-20)
|
||||
# for rl-diffuser https://arxiv.org/abs/2205.09991
|
||||
elif self.config.variance_type == "fixed_small_log":
|
||||
elif variance_type == "fixed_small_log":
|
||||
variance = self.log(self.clip(variance, min_value=1e-20))
|
||||
elif self.config.variance_type == "fixed_large":
|
||||
elif variance_type == "fixed_large":
|
||||
variance = self.betas[t]
|
||||
elif variance_type == "fixed_large_log":
|
||||
# GLIDE max_log
|
||||
variance = self.log(self.betas[t])
|
||||
|
||||
return variance
|
||||
|
||||
|
|
Loading…
Reference in New Issue