refactor GLIDE text2im pipeline, remove classifier_free_guidance

parent 072d75196c
commit 9e31c6a749
@@ -10,6 +10,7 @@ from datasets import load_dataset
 from diffusers import DDPM, DDPMScheduler, UNetModel
 from diffusers.hub_utils import init_git_repo, push_to_hub
 from diffusers.modeling_utils import unwrap_model
+from diffusers.optimization import get_scheduler
 from diffusers.utils import logging
 from torchvision.transforms import (
     CenterCrop,
@@ -21,7 +22,6 @@ from torchvision.transforms import (
     ToTensor,
 )
 from tqdm.auto import tqdm
-from diffusers.optimization import get_scheduler


 logger = logging.get_logger(__name__)
@@ -13,7 +13,6 @@ from .models.unet_rl import TemporalUNet
 from .pipeline_utils import DiffusionPipeline
 from .pipelines import BDDM, DDIM, DDPM, PNDM
 from .schedulers import DDIMScheduler, DDPMScheduler, GradTTSScheduler, PNDMScheduler, SchedulerMixin
-from .schedulers.classifier_free_guidance import ClassifierFreeGuidanceScheduler


 if is_transformers_available():
@@ -36,7 +36,6 @@ LOADABLE_CLASSES = {
         "ModelMixin": ["save_pretrained", "from_pretrained"],
         "SchedulerMixin": ["save_config", "from_config"],
         "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
-        "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
     },
     "transformers": {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
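For context, LOADABLE_CLASSES maps a base class name to the method-name pair used to save and load pipeline components, so deleting the entry takes ClassifierFreeGuidanceScheduler out of pipeline serialization. A hedged sketch of how such a registry is typically consumed (the dispatch loop below is illustrative, not the library's exact code):

import importlib

def save_component(component, save_directory, loadable_classes):
    # find the first registered base class the component inherits from,
    # then call its registered save method by name
    for library_name, classes in loadable_classes.items():
        library = importlib.import_module(library_name)
        for class_name, (save_method, _load_method) in classes.items():
            base_class = getattr(library, class_name, None)
            if base_class is not None and isinstance(component, base_class):
                getattr(component, save_method)(save_directory)
                return
    raise ValueError(f"no registered base class for {component.__class__.__name__}")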
@@ -32,7 +32,7 @@ from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward

 from ..models import GLIDESuperResUNetModel, GLIDETextToImageUNetModel
 from ..pipeline_utils import DiffusionPipeline
-from ..schedulers import ClassifierFreeGuidanceScheduler, DDIMScheduler
+from ..schedulers import DDIMScheduler, DDPMScheduler
 from ..utils import logging


@@ -715,7 +715,7 @@ class GLIDE(DiffusionPipeline):
     def __init__(
         self,
         text_unet: GLIDETextToImageUNetModel,
-        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
+        text_noise_scheduler: DDPMScheduler,
         text_encoder: CLIPTextModel,
         tokenizer: GPT2Tokenizer,
         upscale_unet: GLIDESuperResUNetModel,
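With the dedicated scheduler gone, the text branch takes a plain DDPMScheduler. A hedged construction sketch; the keyword names below are inferred from how this diff uses the scheduler (len(), get_variance, the variance-type strings) and may not match this revision's exact signature:

from diffusers import DDPMScheduler

# the GLIDE cosine schedule lives in DDPMScheduler as "squaredcos_cap_v2";
# variance_type matches the "fixed_small_log" lookup in the denoising loop below
text_noise_scheduler = DDPMScheduler(
    timesteps=1000,
    beta_schedule="squaredcos_cap_v2",
    variance_type="fixed_small_log",
)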
@@ -731,100 +731,28 @@ class GLIDE(DiffusionPipeline):
             upscale_noise_scheduler=upscale_noise_scheduler,
         )

-    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
-        """
-        Compute the mean and variance of the diffusion posterior:
-
-            q(x_{t-1} | x_t, x_0)
-
-        """
-        assert x_start.shape == x_t.shape
-        posterior_mean = (
-            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
-            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
-        )
-        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
-        posterior_log_variance_clipped = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x_t.shape)
-        assert (
-            posterior_mean.shape[0]
-            == posterior_variance.shape[0]
-            == posterior_log_variance_clipped.shape[0]
-            == x_start.shape[0]
-        )
-        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
-    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
-        """
-        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
-        the initial x, x_0.
-
-        :param model: the model, which takes a signal and a batch of timesteps
-                      as input.
-        :param x: the [N x C x ...] tensor at time t.
-        :param t: a 1-D Tensor of timesteps.
-        :param clip_denoised: if True, clip the denoised signal into [-1, 1].
-        :param model_kwargs: if not None, a dict of extra keyword arguments to
-            pass to the model. This can be used for conditioning.
-        :return: a dict with the following keys:
-                 - 'mean': the model mean output.
-                 - 'variance': the model variance output.
-                 - 'log_variance': the log of 'variance'.
-                 - 'pred_xstart': the prediction for x_0.
-        """
-
-        B, C = x.shape[:2]
-        assert t.shape == (B,)
-        if transformer_out is None:
-            # super-res model
-            model_output = model(x, t, low_res)
-        else:
-            # text2image model
-            model_output = model(x, t, transformer_out)
-
-        assert model_output.shape == (B, C * 2, *x.shape[2:])
-        model_output, model_var_values = torch.split(model_output, C, dim=1)
-        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
-        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
-        # The model_var_values is [-1, 1] for [min_var, max_var].
-        frac = (model_var_values + 1) / 2
-        model_log_variance = frac * max_log + (1 - frac) * min_log
-        model_variance = torch.exp(model_log_variance)
-
-        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
-        if clip_denoised:
-            pred_xstart = pred_xstart.clamp(-1, 1)
-        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)
-
-        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
-        return model_mean, model_variance, model_log_variance, pred_xstart
-
-    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
-        assert x_t.shape == eps.shape
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
-            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
-        )
-
-    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
-        return (
-            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
-        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
-
     @torch.no_grad()
-    def __call__(self, prompt, generator=None, torch_device=None, num_inference_steps_upscale=50):
+    def __call__(
+        self,
+        prompt,
+        generator=None,
+        torch_device=None,
+        num_inference_steps_upscale=50,
+        guidance_scale=3.0,
+        eta=0.0,
+        upsample_temp=0.997,
+    ):

         torch_device = "cuda" if torch.cuda.is_available() else "cpu"

         self.text_unet.to(torch_device)
         self.text_encoder.to(torch_device)
         self.upscale_unet.to(torch_device)

-        # Create a classifier-free guidance sampling function
-        guidance_scale = 3.0
-
-        def text_model_fn(x_t, ts, transformer_out, **kwargs):
+        def text_model_fn(x_t, timesteps, transformer_out, **kwargs):
             half = x_t[: len(x_t) // 2]
             combined = torch.cat([half, half], dim=0)
-            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
+            model_out = self.text_unet(combined, timesteps, transformer_out, **kwargs)
             eps, rest = model_out[:, :3], model_out[:, 3:]
             cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
             half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
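text_model_fn implements classifier-free guidance: the conditional and unconditional halves of the batch share the same latents, and the two noise predictions are extrapolated by guidance_scale. A self-contained sketch of the full function (the hunk truncates it; the last two lines follow the reference GLIDE implementation, and `model` is a stand-in for the UNet):

import torch

def classifier_free_guidance_eps(model, x_t, timesteps, transformer_out, guidance_scale=3.0):
    half = x_t[: len(x_t) // 2]                       # conditional and unconditional halves share latents
    combined = torch.cat([half, half], dim=0)
    model_out = model(combined, timesteps, transformer_out)
    eps, rest = model_out[:, :3], model_out[:, 3:]    # first 3 channels: predicted noise; rest: variance values
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)      # duplicate so the output batch matches the input batch
    return torch.cat([eps, rest], dim=1)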
@@ -833,7 +761,15 @@ class GLIDE(DiffusionPipeline):

         # 1. Sample gaussian noise
         batch_size = 2  # second image is empty for classifier-free guidance
-        image = torch.randn((batch_size, self.text_unet.in_channels, 64, 64), generator=generator).to(torch_device)
+        image = torch.randn(
+            (
+                batch_size,
+                self.text_unet.in_channels,
+                self.text_unet.resolution,
+                self.text_unet.resolution,
+            ),
+            generator=generator,
+        ).to(torch_device)

         # 2. Encode tokens
         # an empty input is needed to guide the model away from it
@@ -843,25 +779,30 @@ class GLIDE(DiffusionPipeline):
         transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

         # 3. Run the text2image generation step
-        num_timesteps = len(self.text_noise_scheduler)
-        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
-            t = torch.tensor([i] * image.shape[0], device=torch_device)
-            mean, variance, log_variance, pred_xstart = self.p_mean_variance(
-                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
-            )
+        num_prediction_steps = len(self.text_noise_scheduler)
+        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+            with torch.no_grad():
+                time_input = torch.tensor([t] * image.shape[0], device=torch_device)
+                model_output = text_model_fn(image, time_input, transformer_out)
+                noise_residual, model_var_values = torch.split(model_output, 3, dim=1)
+
+                min_log = self.text_noise_scheduler.get_variance(t, "fixed_small_log")
+                max_log = self.text_noise_scheduler.get_variance(t, "fixed_large_log")
+                # The model_var_values is [-1, 1] for [min_var, max_var].
+                frac = (model_var_values + 1) / 2
+                model_log_variance = frac * max_log + (1 - frac) * min_log
+
+                pred_prev_image = self.text_noise_scheduler.step(noise_residual, image, t)
                 noise = torch.randn(image.shape, generator=generator).to(torch_device)
-            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))  # no noise when t == 0
-            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise
+                variance = torch.exp(0.5 * model_log_variance) * noise
+
+                # set current image to prev_image: x_t -> x_t-1
+                image = pred_prev_image + variance

         # 4. Run the upscaling step
         batch_size = 1
         image = image[:1]
         low_res = ((image + 1) * 127.5).round() / 127.5 - 1
-        eta = 0.0
-
-        # Tune this parameter to control the sharpness of 256x256 images.
-        # A value of 1.0 is sharper, but sometimes results in grainy artifacts.
-        upsample_temp = 0.997

         # Sample gaussian noise to begin loop
         image = torch.randn(
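Between the two bounds returned by get_variance, the model's extra output channels pick the actual log-variance. A minimal numeric sketch of that interpolation (toy values; in the loop above the bounds come from the scheduler):

import torch

min_log = torch.log(torch.tensor(1e-4))   # e.g. get_variance(t, "fixed_small_log")
max_log = torch.log(torch.tensor(2e-2))   # e.g. get_variance(t, "fixed_large_log")
model_var_values = torch.tensor(0.5)      # hypothetical network output channel in [-1, 1]

frac = (model_var_values + 1) / 2                           # map [-1, 1] -> [0, 1]
model_log_variance = frac * max_log + (1 - frac) * min_log  # log-space interpolation
sigma = torch.exp(0.5 * model_log_variance)                 # std used to scale the sampled noise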
@@ -877,8 +818,6 @@ class GLIDE(DiffusionPipeline):

         num_trained_timesteps = self.upscale_noise_scheduler.timesteps
         inference_step_times = range(0, num_trained_timesteps, num_trained_timesteps // num_inference_steps_upscale)
-        # adapt the beta schedule to the number of steps
-        # self.upscale_noise_scheduler.rescale_betas(num_inference_steps_upscale)

         for t in tqdm.tqdm(reversed(range(num_inference_steps_upscale)), total=num_inference_steps_upscale):
             # 1. predict noise residual
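The upscaler keeps the stride-based step selection. A worked example of the arithmetic, assuming 1000 trained timesteps and the default of 50 inference steps:

num_trained_timesteps = 1000
num_inference_steps_upscale = 50
stride = num_trained_timesteps // num_inference_steps_upscale   # 20
inference_step_times = list(range(0, num_trained_timesteps, stride))
print(inference_step_times[:4], inference_step_times[-1])       # [0, 20, 40, 60] 980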
@@ -16,7 +16,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .classifier_free_guidance import ClassifierFreeGuidanceScheduler
 from .scheduling_ddim import DDIMScheduler
 from .scheduling_ddpm import DDPMScheduler
 from .scheduling_grad_tts import GradTTSScheduler
@ -1,96 +0,0 @@
|
||||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import math
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
|
|
||||||
from ..configuration_utils import ConfigMixin
|
|
||||||
|
|
||||||
|
|
||||||
SAMPLING_CONFIG_NAME = "scheduler_config.json"
|
|
||||||
|
|
||||||
|
|
||||||
def linear_beta_schedule(timesteps, beta_start, beta_end):
|
|
||||||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
|
||||||
|
|
||||||
|
|
||||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
|
||||||
"""
|
|
||||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
|
||||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
|
||||||
|
|
||||||
:param num_diffusion_timesteps: the number of betas to produce.
|
|
||||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
|
||||||
produces the cumulative product of (1-beta) up to that
|
|
||||||
part of the diffusion process.
|
|
||||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
|
||||||
prevent singularities.
|
|
||||||
"""
|
|
||||||
betas = []
|
|
||||||
for i in range(num_diffusion_timesteps):
|
|
||||||
t1 = i / num_diffusion_timesteps
|
|
||||||
t2 = (i + 1) / num_diffusion_timesteps
|
|
||||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
|
||||||
return np.array(betas, dtype=np.float64)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassifierFreeGuidanceScheduler(nn.Module, ConfigMixin):
|
|
||||||
|
|
||||||
config_name = SAMPLING_CONFIG_NAME
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
timesteps=1000,
|
|
||||||
beta_schedule="squaredcos_cap_v2",
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.register_to_config(
|
|
||||||
timesteps=timesteps,
|
|
||||||
beta_schedule=beta_schedule,
|
|
||||||
)
|
|
||||||
|
|
||||||
if beta_schedule == "squaredcos_cap_v2":
|
|
||||||
# GLIDE cosine schedule
|
|
||||||
self.betas = betas_for_alpha_bar(
|
|
||||||
timesteps,
|
|
||||||
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
|
||||||
|
|
||||||
alphas = 1.0 - self.betas
|
|
||||||
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
|
||||||
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
|
||||||
|
|
||||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
|
||||||
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
|
||||||
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
|
||||||
|
|
||||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
|
||||||
self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
|
||||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
|
||||||
self.posterior_log_variance_clipped = np.log(
|
|
||||||
np.append(self.posterior_variance[1], self.posterior_variance[1:])
|
|
||||||
)
|
|
||||||
self.posterior_mean_coef1 = self.betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
|
||||||
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
|
||||||
|
|
||||||
def sample_noise(self, shape, device, generator=None):
|
|
||||||
# always sample on CPU to be deterministic
|
|
||||||
return torch.randn(shape, generator=generator).to(device)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return self.config.timesteps
|
|
|
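For reference, the quantities this deleted scheduler precomputed are the standard DDPM posterior terms; restated in math, matching posterior_mean_coef1, posterior_mean_coef2, and posterior_variance above:

q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t I\big)

\tilde{\mu}_t(x_t, x_0) = \underbrace{\frac{\beta_t \sqrt{\bar{\alpha}_{t-1}}}{1 - \bar{\alpha}_t}}_{\texttt{posterior\_mean\_coef1}} x_0 + \underbrace{\frac{(1 - \bar{\alpha}_{t-1}) \sqrt{\alpha_t}}{1 - \bar{\alpha}_t}}_{\texttt{posterior\_mean\_coef2}} x_t, \qquad \tilde{\beta}_t = \underbrace{\frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t} \beta_t}_{\texttt{posterior\_variance}}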
@@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):

         self.set_format(tensor_format=tensor_format)

-    def get_variance(self, t):
+    def get_variance(self, t, variance_type=None):
         alpha_prod_t = self.alphas_cumprod[t]
         alpha_prod_t_prev = self.alphas_cumprod[t - 1] if t > 0 else self.one

@@ -96,14 +96,20 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
         variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[t]

+        if variance_type is None:
+            variance_type = self.config.variance_type
+
         # hacks - were probs added for training stability
-        if self.config.variance_type == "fixed_small":
+        if variance_type == "fixed_small":
             variance = self.clip(variance, min_value=1e-20)
         # for rl-diffuser https://arxiv.org/abs/2205.09991
-        elif self.config.variance_type == "fixed_small_log":
+        elif variance_type == "fixed_small_log":
             variance = self.log(self.clip(variance, min_value=1e-20))
-        elif self.config.variance_type == "fixed_large":
+        elif variance_type == "fixed_large":
             variance = self.betas[t]
+        elif variance_type == "fixed_large_log":
+            # GLIDE max_log
+            variance = self.log(self.betas[t])

         return variance

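A short usage sketch of the new per-call override (assuming a DDPMScheduler instance `scheduler` and an integer timestep `t` are in scope), which lets the GLIDE loop fetch both log-variance bounds from one scheduler:

min_log = scheduler.get_variance(t, "fixed_small_log")  # log of the clipped posterior variance
max_log = scheduler.get_variance(t, "fixed_large_log")  # log(beta_t), the GLIDE "max_log"
default = scheduler.get_variance(t)                     # falls back to self.config.variance_type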