k-diffusion-euler (#1019)
* k-diffusion-euler
* make style, make quality
* make fix-copies
* fix tests for euler a
* Update src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py (co-authored by Anton Lozhkov)
* Update src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py (co-authored by Anton Lozhkov)
* Update src/diffusers/schedulers/scheduling_euler_discrete.py (co-authored by Anton Lozhkov)
* Update src/diffusers/schedulers/scheduling_euler_discrete.py (co-authored by Anton Lozhkov)
* remove unused arg and method
* update doc
* quality
* make flake happy
* use logger instead of warn
* raise error instead of deprecation
* don't require scipy
* pass generator in step
* fix tests
* Apply suggestions from code review (co-authored by Pedro Cuenca)
* Update tests/test_scheduler.py (co-authored by Patrick von Platen)
* remove unused generator
* pass generator as extra_step_kwargs
* update tests
* pass generator as kwarg
* pass generator as kwarg
* quality
* fix test for lms
* fix tests

Co-authored-by: patil-suraj <surajp815@gmail.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent bf7b0bc25b
commit a1ea8c01c3
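Usage sketch (not part of the diff): how the two schedulers added by this PR plug into a Stable Diffusion pipeline. The checkpoint id is an assumption for illustration; the scheduler kwargs mirror the new fast tests below.

    import torch
    from diffusers import EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, StableDiffusionPipeline

    # Hypothetical checkpoint id, for illustration only.
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    # Deterministic Euler (Algorithm 2 of Karras et al. 2022):
    pipe.scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
    # ...or the stochastic ancestral variant:
    # pipe.scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

    generator = torch.Generator(device="cpu").manual_seed(0)
    image = pipe("A painting of a squirrel eating a burger", num_inference_steps=30, generator=generator).images[0]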
@@ -41,6 +41,8 @@ if is_torch_available():
     from .schedulers import (
         DDIMScheduler,
         DDPMScheduler,
+        EulerAncestralDiscreteScheduler,
+        EulerDiscreteScheduler,
         IPNDMScheduler,
         KarrasVeScheduler,
         PNDMScheduler,
@@ -9,7 +9,13 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import (
+    DDIMScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
 from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -52,7 +58,9 @@ class StableDiffusionPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[
+            DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
+        ],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
     ):
@@ -334,6 +342,11 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
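The generator handling above uses runtime signature inspection: `generator` is only forwarded to schedulers whose `step()` accepts it (the stochastic `EulerAncestralDiscreteScheduler`), leaving DDIM/PNDM/LMS calls untouched. The same probe recurs in the img2img and inpaint pipelines below. A minimal standalone sketch of the pattern; `step_fn` is a hypothetical stand-in for `self.scheduler.step`:

    import inspect

    def build_extra_step_kwargs(step_fn, generator=None, eta=0.0):
        # Forward only the kwargs the concrete scheduler's step() can accept.
        extra = {}
        params = set(inspect.signature(step_fn).parameters.keys())
        if "eta" in params:  # DDIM-style schedulers take eta
            extra["eta"] = eta
        if "generator" in params:  # stochastic schedulers take a seeded RNG
            extra["generator"] = generator
        return extra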
@@ -10,7 +10,13 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import (
+    DDIMScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
+    PNDMScheduler,
+)
 from ...utils import deprecate, logging
 from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
@@ -63,7 +69,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: Union[
+            DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
+        ],
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
     ):
@@ -335,6 +343,11 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -379,6 +379,11 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -352,6 +352,11 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+
         latents = init_latents
 
         t_start = max(num_inference_steps - init_timestep + offset, 0)
@@ -19,6 +19,8 @@ from ..utils import is_flax_available, is_scipy_available, is_torch_available
 if is_torch_available():
     from .scheduling_ddim import DDIMScheduler
     from .scheduling_ddpm import DDPMScheduler
+    from .scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler
+    from .scheduling_euler_discrete import EulerDiscreteScheduler
     from .scheduling_ipndm import IPNDMScheduler
     from .scheduling_karras_ve import KarrasVeScheduler
     from .scheduling_pndm import PNDMScheduler
@@ -0,0 +1,261 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, deprecate, logging
+from .scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete
+class EulerAncestralDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson:
+    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end`, etc.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+    ):
+        if trained_betas is not None:
+            self.betas = torch.from_numpy(trained_betas)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+        self.sigmas = torch.from_numpy(sigmas)
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = self.sigmas.max()
+
+        # setable values
+        self.num_inference_steps = None
+        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps)
+        self.is_scale_input_called = False
+
+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+    ) -> torch.FloatTensor:
+        """
+        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        self.is_scale_input_called = True
+        return sample
+
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        """
+        self.num_inference_steps = num_inference_steps
+
+        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+        self.sigmas = torch.from_numpy(sigmas).to(device=device)
+        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+        sample: torch.FloatTensor,
+        generator: Optional[torch.Generator] = None,
+        return_dict: bool = True,
+    ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`float`): current timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+            generator (`torch.Generator`, optional): random number generator.
+            return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class
+
+        Returns:
+            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`:
+            [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise
+            a `tuple`. When returning a tuple, the first element is the sample tensor.
+        """
+
+        if (
+            isinstance(timestep, int)
+            or isinstance(timestep, torch.IntTensor)
+            or isinstance(timestep, torch.LongTensor)
+        ):
+            raise ValueError(
+                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                " `EulerAncestralDiscreteScheduler.step()` is not supported. Make sure to pass"
+                " one of the `scheduler.timesteps` as a timestep.",
+            )
+
+        if not self.is_scale_input_called:
+            logger.warning(
+                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                "See `StableDiffusionPipeline` for a usage example."
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+
+        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+        pred_original_sample = sample - sigma * model_output
+        sigma_from = self.sigmas[step_index]
+        sigma_to = self.sigmas[step_index + 1]
+        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
+        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+
+        # 2. Convert to an ODE derivative
+        derivative = (sample - pred_original_sample) / sigma
+
+        dt = sigma_down - sigma
+
+        prev_sample = sample + derivative * dt
+
+        device = model_output.device if torch.is_tensor(model_output) else "cpu"
+        noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
+        prev_sample = prev_sample + noise * sigma_up
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return EulerAncestralDiscreteSchedulerOutput(
+            prev_sample=prev_sample, pred_original_sample=pred_original_sample
+        )
+
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        # Make sure sigmas and timesteps have the same device and dtype as original_samples
+        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+            # mps does not support float64
+            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+        else:
+            self.timesteps = self.timesteps.to(original_samples.device)
+            timesteps = timesteps.to(original_samples.device)
+
+        schedule_timesteps = self.timesteps
+
+        if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
+            deprecate(
+                "timesteps as indices",
+                "0.8.0",
+                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
+                " pass values from `scheduler.timesteps` as timesteps.",
+                standard_warn=False,
+            )
+            step_indices = timesteps
+        else:
+            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = self.sigmas[step_indices].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
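A note on the ancestral update above: the transition from `sigma_from` to `sigma_to` is split into a deterministic Euler move down to `sigma_down` plus fresh noise with standard deviation `sigma_up`, chosen so that the variances recombine to `sigma_to**2`. A quick standalone check of that identity (the example values are assumed, not from the diff):

    # sigma_down**2 + sigma_up**2 == sigma_to**2 by construction, so the marginal
    # variance after the deterministic step plus fresh noise matches sigma_to.
    sigma_from, sigma_to = 2.0, 1.5  # arbitrary example values
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    assert abs(sigma_down**2 + sigma_up**2 - sigma_to**2) < 1e-12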
@@ -0,0 +1,270 @@
+# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, deprecate, logging
+from .scheduling_utils import SchedulerMixin
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
+class EulerDiscreteSchedulerOutput(BaseOutput):
+    """
+    Output class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+            `pred_original_sample` can be used to preview progress or for guidance.
+    """
+
+    prev_sample: torch.FloatTensor
+    pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+    """
+    Euler scheduler (Algorithm 2) from Karras et al. (2022), https://arxiv.org/abs/2206.00364. Based on the original
+    k-diffusion implementation by Katherine Crowson:
+    https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
+
+    [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+    function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+    [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
+    [`~ConfigMixin.from_config`] functions.
+
+    Args:
+        num_train_timesteps (`int`): number of diffusion steps used to train the model.
+        beta_start (`float`): the starting `beta` value of inference.
+        beta_end (`float`): the final `beta` value.
+        beta_schedule (`str`):
+            the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+            `linear` or `scaled_linear`.
+        trained_betas (`np.ndarray`, optional):
+            option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end`, etc.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+    ):
+        if trained_betas is not None:
+            self.betas = torch.from_numpy(trained_betas)
+        elif beta_schedule == "linear":
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+        elif beta_schedule == "scaled_linear":
+            # this schedule is very specific to the latent diffusion model.
+            self.betas = (
+                torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+            )
+        else:
+            raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+        self.alphas = 1.0 - self.betas
+        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
+        self.sigmas = torch.from_numpy(sigmas)
+
+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = self.sigmas.max()
+
+        # setable values
+        self.num_inference_steps = None
+        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+        self.timesteps = torch.from_numpy(timesteps)
+        self.is_scale_input_called = False
+
+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+    ) -> torch.FloatTensor:
+        """
+        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        self.is_scale_input_called = True
+        return sample
+
+    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+        """
+        Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+        Args:
+            num_inference_steps (`int`):
+                the number of diffusion steps used when generating samples with a pre-trained model.
+            device (`str` or `torch.device`, optional):
+                the device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        """
+        self.num_inference_steps = num_inference_steps
+
+        timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
+        self.sigmas = torch.from_numpy(sigmas).to(device=device)
+        self.timesteps = torch.from_numpy(timesteps).to(device=device)
+
+    def step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: Union[float, torch.FloatTensor],
+        sample: torch.FloatTensor,
+        s_churn: float = 0.0,
+        s_tmin: float = 0.0,
+        s_tmax: float = float("inf"),
+        s_noise: float = 1.0,
+        generator: Optional[torch.Generator] = None,
+        return_dict: bool = True,
+    ) -> Union[EulerDiscreteSchedulerOutput, Tuple]:
+        """
+        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+        process from the learned model outputs (most often the predicted noise).
+
+        Args:
+            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+            timestep (`float`): current timestep in the diffusion chain.
+            sample (`torch.FloatTensor`):
+                current instance of sample being created by diffusion process.
+            s_churn (`float`): amount of stochastic churn added per step (0.0 gives a fully deterministic step).
+            s_tmin (`float`): lower sigma bound of the range where churn is applied.
+            s_tmax (`float`): upper sigma bound of the range where churn is applied.
+            s_noise (`float`): scaling of the noise added during churn.
+            generator (`torch.Generator`, optional): random number generator.
+            return_dict (`bool`): option for returning tuple rather than EulerDiscreteSchedulerOutput class
+
+        Returns:
+            [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] or `tuple`:
+            [`~schedulers.scheduling_utils.EulerDiscreteSchedulerOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is the sample tensor.
+        """
+
+        if (
+            isinstance(timestep, int)
+            or isinstance(timestep, torch.IntTensor)
+            or isinstance(timestep, torch.LongTensor)
+        ):
+            raise ValueError(
+                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
+                " one of the `scheduler.timesteps` as a timestep.",
+            )
+
+        if not self.is_scale_input_called:
+            logger.warning(
+                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                "See `StableDiffusionPipeline` for a usage example."
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+
+        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
+
+        device = model_output.device if torch.is_tensor(model_output) else "cpu"
+        noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator).to(device)
+        eps = noise * s_noise
+        sigma_hat = sigma * (gamma + 1)
+
+        if gamma > 0:
+            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
+
+        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
+        pred_original_sample = sample - sigma_hat * model_output
+
+        # 2. Convert to an ODE derivative
+        derivative = (sample - pred_original_sample) / sigma_hat
+
+        dt = self.sigmas[step_index + 1] - sigma_hat
+
+        prev_sample = sample + derivative * dt
+
+        if not return_dict:
+            return (prev_sample,)
+
+        return EulerDiscreteSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
+
+    def add_noise(
+        self,
+        original_samples: torch.FloatTensor,
+        noise: torch.FloatTensor,
+        timesteps: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        # Make sure sigmas and timesteps have the same device and dtype as original_samples
+        self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
+        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
+            # mps does not support float64
+            self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
+            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
+        else:
+            self.timesteps = self.timesteps.to(original_samples.device)
+            timesteps = timesteps.to(original_samples.device)
+
+        schedule_timesteps = self.timesteps
+
+        if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
+            deprecate(
+                "timesteps as indices",
+                "0.8.0",
+                "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
+                " `EulerDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to"
+                " pass values from `scheduler.timesteps` as timesteps.",
+                standard_warn=False,
+            )
+            step_indices = timesteps
+        else:
+            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = self.sigmas[step_indices].flatten()
+        while len(sigma.shape) < len(original_samples.shape):
+            sigma = sigma.unsqueeze(-1)
+
+        noisy_samples = original_samples + noise * sigma
+        return noisy_samples
+
+    def __len__(self):
+        return self.config.num_train_timesteps
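In the `step()` above, `gamma` implements the stochastic "churn" of Karras et al.: `sigma` is temporarily inflated to `sigma_hat = sigma * (gamma + 1)` with matching noise before a plain Euler step. With the default `s_churn=0.0`, `gamma` is always 0 and the sampler is fully deterministic, as a standalone check shows (values assumed for illustration):

    s_churn, s_tmin, s_tmax = 0.0, 0.0, float("inf")  # the scheduler defaults
    num_sigmas, sigma = 11, 1.0  # example values
    gamma = min(s_churn / (num_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
    assert gamma == 0.0  # sigma_hat == sigma, so the step reduces to deterministic Euler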
@@ -272,6 +272,36 @@ class DDPMScheduler(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
 
 
+class EulerAncestralDiscreteScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
+class EulerDiscreteScheduler(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class IPNDMScheduler(metaclass=DummyObject):
     _backends = ["torch"]
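These dummy classes keep `from diffusers import EulerDiscreteScheduler` importable when torch is absent, failing with a backend error only on use. A simplified sketch of the idea (the real `DummyObject` and `requires_backends` live in `diffusers.utils`; this stand-in is illustrative only):

    class DummyObject(type):
        # Metaclass stand-in: instantiating any class that uses it raises immediately.
        def __call__(cls, *args, **kwargs):
            raise ImportError(f"{cls.__name__} requires the `torch` backend, which is not installed.")

    class EulerDiscreteSchedulerStub(metaclass=DummyObject):
        _backends = ["torch"]

    # EulerDiscreteSchedulerStub()  # -> ImportError: EulerDiscreteSchedulerStub requires the `torch` backend...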
@@ -24,6 +24,8 @@ import torch
 from diffusers import (
     AutoencoderKL,
     DDIMScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionPipeline,
@@ -361,6 +363,96 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_k_euler_ancestral(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = EulerAncestralDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            return_dict=False,
+        )[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 128, 128, 3)
+        expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+    def test_stable_diffusion_k_euler(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = EulerDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
+
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            return_dict=False,
+        )[0]
+
+        image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 128, 128, 3)
+        expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_attention_chunk(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import tempfile
 import unittest
 from typing import Dict, List, Tuple
@@ -22,6 +23,8 @@ import torch
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
     IPNDMScheduler,
     LMSDiscreteScheduler,
     PNDMScheduler,
@@ -77,7 +80,11 @@ class SchedulerCommonTest(unittest.TestCase):
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
 
+        # TODO(Suraj) - delete the following two lines once DDPM, DDIM, and PNDM have timesteps casted to float by default
         for scheduler_class in self.scheduler_classes:
+            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
+                time_step = float(time_step)
+
             sample = self.dummy_sample
             residual = 0.1 * sample
@@ -94,7 +101,13 @@ class SchedulerCommonTest(unittest.TestCase):
             elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                 kwargs["num_inference_steps"] = num_inference_steps
 
+            # Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
+            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+                kwargs["generator"] = torch.Generator().manual_seed(0)
             output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+
+            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+                kwargs["generator"] = torch.Generator().manual_seed(0)
             new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
 
             assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -106,6 +119,9 @@ class SchedulerCommonTest(unittest.TestCase):
         num_inference_steps = kwargs.pop("num_inference_steps", None)
 
         for scheduler_class in self.scheduler_classes:
+            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
+                time_step = float(time_step)
+
             sample = self.dummy_sample
             residual = 0.1 * sample
@@ -122,9 +138,12 @@ class SchedulerCommonTest(unittest.TestCase):
             elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                 kwargs["num_inference_steps"] = num_inference_steps
 
-            torch.manual_seed(0)
+            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+                kwargs["generator"] = torch.Generator().manual_seed(0)
             output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
-            torch.manual_seed(0)
+
+            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+                kwargs["generator"] = torch.Generator().manual_seed(0)
             new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
 
             assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -141,6 +160,10 @@ class SchedulerCommonTest(unittest.TestCase):
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
 
+            timestep = 1
+            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
+                timestep = float(timestep)
+
             with tempfile.TemporaryDirectory() as tmpdirname:
                 scheduler.save_config(tmpdirname)
                 new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -151,10 +174,13 @@ class SchedulerCommonTest(unittest.TestCase):
             elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                 kwargs["num_inference_steps"] = num_inference_steps
 
-            torch.manual_seed(0)
-            output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
-            torch.manual_seed(0)
-            new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample
+            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+                kwargs["generator"] = torch.Generator().manual_seed(0)
+            output = scheduler.step(residual, timestep, sample, **kwargs).prev_sample
+
+            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+                kwargs["generator"] = torch.Generator().manual_seed(0)
+            new_output = new_scheduler.step(residual, timestep, sample, **kwargs).prev_sample
 
             assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@@ -163,7 +189,14 @@ class SchedulerCommonTest(unittest.TestCase):
 
         num_inference_steps = kwargs.pop("num_inference_steps", None)
 
+        timestep_0 = 0
+        timestep_1 = 1
+
         for scheduler_class in self.scheduler_classes:
+            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
+                timestep_0 = float(timestep_0)
+                timestep_1 = float(timestep_1)
+
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
@@ -175,8 +208,8 @@ class SchedulerCommonTest(unittest.TestCase):
             elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                 kwargs["num_inference_steps"] = num_inference_steps
 
-            output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample
-            output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample
+            output_0 = scheduler.step(residual, timestep_0, sample, **kwargs).prev_sample
+            output_1 = scheduler.step(residual, timestep_1, sample, **kwargs).prev_sample
 
             self.assertEqual(output_0.shape, sample.shape)
             self.assertEqual(output_0.shape, output_1.shape)
@@ -216,6 +249,9 @@ class SchedulerCommonTest(unittest.TestCase):
         timestep = 1
 
         for scheduler_class in self.scheduler_classes:
+            if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
+                timestep = float(timestep)
+
             scheduler_config = self.get_scheduler_config()
             scheduler = scheduler_class(**scheduler_config)
@@ -227,6 +263,9 @@ class SchedulerCommonTest(unittest.TestCase):
             elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                 kwargs["num_inference_steps"] = num_inference_steps
 
+            # Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
+            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+                kwargs["generator"] = torch.Generator().manual_seed(0)
             outputs_dict = scheduler.step(residual, timestep, sample, **kwargs)
 
             if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
@@ -234,6 +273,9 @@ class SchedulerCommonTest(unittest.TestCase):
             elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
                 kwargs["num_inference_steps"] = num_inference_steps
 
+            # Set the seed before step() as some schedulers are stochastic like EulerAncestralDiscreteScheduler, EulerDiscreteScheduler
+            if "generator" in set(inspect.signature(scheduler.step).parameters.keys()):
+                kwargs["generator"] = torch.Generator().manual_seed(0)
             outputs_tuple = scheduler.step(residual, timestep, sample, return_dict=False, **kwargs)
 
             recursive_check(outputs_tuple, outputs_dict)
@@ -933,6 +975,117 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
         assert abs(result_mean.item() - 1.31) < 1e-3
 
 
+class EulerDiscreteSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (EulerDiscreteScheduler,)
+    num_inference_steps = 10
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1100,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "trained_betas": None,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def test_timesteps(self):
+        for timesteps in [10, 50, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_betas(self):
+        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
+            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+    def test_schedules(self):
+        for schedule in ["linear", "scaled_linear"]:
+            self.check_over_configs(beta_schedule=schedule)
+
+    def test_full_loop_no_noise(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        generator = torch.Generator().manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 10.0807) < 1e-2
+        assert abs(result_mean.item() - 0.0131) < 1e-3
+
+
+class EulerAncestralDiscreteSchedulerTest(SchedulerCommonTest):
+    scheduler_classes = (EulerAncestralDiscreteScheduler,)
+    num_inference_steps = 10
+
+    def get_scheduler_config(self, **kwargs):
+        config = {
+            "num_train_timesteps": 1100,
+            "beta_start": 0.0001,
+            "beta_end": 0.02,
+            "beta_schedule": "linear",
+            "trained_betas": None,
+        }
+
+        config.update(**kwargs)
+        return config
+
+    def test_timesteps(self):
+        for timesteps in [10, 50, 100, 1000]:
+            self.check_over_configs(num_train_timesteps=timesteps)
+
+    def test_betas(self):
+        for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
+            self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
+
+    def test_schedules(self):
+        for schedule in ["linear", "scaled_linear"]:
+            self.check_over_configs(beta_schedule=schedule)
+
+    def test_full_loop_no_noise(self):
+        scheduler_class = self.scheduler_classes[0]
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config)
+
+        scheduler.set_timesteps(self.num_inference_steps)
+
+        generator = torch.Generator().manual_seed(0)
+
+        model = self.dummy_model()
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma
+
+        for i, t in enumerate(scheduler.timesteps):
+            sample = scheduler.scale_model_input(sample, t)
+
+            model_output = model(sample, t)
+
+            output = scheduler.step(model_output, t, sample, generator=generator)
+            sample = output.prev_sample
+
+        result_sum = torch.sum(torch.abs(sample))
+        result_mean = torch.mean(torch.abs(sample))
+
+        assert abs(result_sum.item() - 152.3192) < 1e-2
+        assert abs(result_mean.item() - 0.1983) < 1e-3
+
+
 class IPNDMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (IPNDMScheduler,)
     forward_default_kwargs = (("num_inference_steps", 50),)
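The new scheduler tests can be selected by class name; an assumed invocation from the repository root (the -k filter matches EulerDiscreteSchedulerTest and EulerAncestralDiscreteSchedulerTest):

    python -m pytest tests/test_scheduler.py -k "Euler" -v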