Add an alternative Karras et al. stochastic scheduler for VE models (#160)
* karras + VE, not flexible yet * Fix inputs incompatibility with the original unet * Roll back sigma scaling * Apply suggestions from code review * Old comment * Fix doc
This commit is contained in:
parent
543ee1e092
commit
dd10da76a7
|
@ -18,8 +18,15 @@ from .optimization import (
|
|||
get_scheduler,
|
||||
)
|
||||
from .pipeline_utils import DiffusionPipeline
|
||||
from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
|
||||
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
|
||||
from .pipelines import DDIMPipeline, DDPMPipeline, KarrasVePipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
|
||||
from .schedulers import (
|
||||
DDIMScheduler,
|
||||
DDPMScheduler,
|
||||
KarrasVeScheduler,
|
||||
PNDMScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
)
|
||||
from .training_utils import EMAModel
|
||||
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from .ddpm import DDPMPipeline
|
|||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pndm import PNDMPipeline
|
||||
from .score_sde_ve import ScoreSdeVePipeline
|
||||
from .stochatic_karras_ve import KarrasVePipeline
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from .pipeline_stochastic_karras_ve import KarrasVePipeline
|
|
@ -0,0 +1,80 @@
|
|||
#!/usr/bin/env python3
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import KarrasVeScheduler
|
||||
|
||||
|
||||
class KarrasVePipeline(DiffusionPipeline):
|
||||
"""
|
||||
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
|
||||
Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
|
||||
|
||||
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
|
||||
[2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
|
||||
"""
|
||||
|
||||
unet: UNet2DModel
|
||||
scheduler: KarrasVeScheduler
|
||||
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
scheduler = scheduler.set_format("pt")
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, num_inference_steps=50, generator=None, torch_device=None, output_type="pil"):
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
img_size = self.unet.config.sample_size
|
||||
shape = (batch_size, 3, img_size, img_size)
|
||||
|
||||
model = self.unet.to(torch_device)
|
||||
|
||||
# sample x_0 ~ N(0, sigma_0^2 * I)
|
||||
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
||||
sample = sample.to(torch_device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
# here sigma_t == t_i from the paper
|
||||
sigma = self.scheduler.schedule[t]
|
||||
sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
|
||||
|
||||
# 1. Select temporarily increased noise level sigma_hat
|
||||
# 2. Add new noise to move from sample_i to sample_hat
|
||||
sample_hat, sigma_hat = self.scheduler.add_noise_to_input(sample, sigma, generator=generator)
|
||||
|
||||
# 3. Predict the noise residual given the noise magnitude `sigma_hat`
|
||||
# The model inputs and output are adjusted by following eq. (213) in [1].
|
||||
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]
|
||||
|
||||
# 4. Evaluate dx/dt at sigma_hat
|
||||
# 5. Take Euler step from sigma to sigma_prev
|
||||
step_output = self.scheduler.step(model_output, sigma_hat, sigma_prev, sample_hat)
|
||||
|
||||
if sigma_prev != 0:
|
||||
# 6. Apply 2nd order correction
|
||||
# The model inputs and output are adjusted by following eq. (213) in [1].
|
||||
model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
|
||||
step_output = self.scheduler.step_correct(
|
||||
model_output,
|
||||
sigma_hat,
|
||||
sigma_prev,
|
||||
sample_hat,
|
||||
step_output["prev_sample"],
|
||||
step_output["derivative"],
|
||||
)
|
||||
sample = step_output["prev_sample"]
|
||||
|
||||
sample = (sample / 2 + 0.5).clamp(0, 1)
|
||||
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if output_type == "pil":
|
||||
sample = self.numpy_to_pil(sample)
|
||||
|
||||
return {"sample": sample}
|
|
@ -18,6 +18,7 @@
|
|||
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
from .scheduling_ddpm import DDPMScheduler
|
||||
from .scheduling_karras_ve import KarrasVeScheduler
|
||||
from .scheduling_pndm import PNDMScheduler
|
||||
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
||||
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
||||
|
|
|
@ -134,7 +134,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
generator=None,
|
||||
):
|
||||
t = timestep
|
||||
|
||||
|
||||
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
|
||||
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
|
||||
else:
|
||||
|
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2022 NVIDIA and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
|
||||
Use Algorithm 2 and the VE column of Table 1 from [1] for reference.
|
||||
|
||||
[1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364
|
||||
[2] Song, Yang, et al. "Score-based generative modeling through stochastic differential equations." https://arxiv.org/abs/2011.13456
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
sigma_min=0.02,
|
||||
sigma_max=100,
|
||||
s_noise=1.007,
|
||||
s_churn=80,
|
||||
s_min=0.05,
|
||||
s_max=50,
|
||||
tensor_format="pt",
|
||||
):
|
||||
"""
|
||||
For more details on the parameters, see the original paper's Appendix E.:
|
||||
"Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364.
|
||||
The grid search values used to find the optimal {s_noise, s_churn, s_min, s_max} for a specific model
|
||||
are described in Table 5 of the paper.
|
||||
|
||||
Args:
|
||||
sigma_min (`float`): minimum noise magnitude
|
||||
sigma_max (`float`): maximum noise magnitude
|
||||
s_noise (`float`): the amount of additional noise to counteract loss of detail during sampling.
|
||||
A reasonable range is [1.000, 1.011].
|
||||
s_churn (`float`): the parameter controlling the overall amount of stochasticity.
|
||||
A reasonable range is [0, 100].
|
||||
s_min (`float`): the start value of the sigma range where we add noise (enable stochasticity).
|
||||
A reasonable range is [0, 10].
|
||||
s_max (`float`): the end value of the sigma range where we add noise.
|
||||
A reasonable range is [0.2, 80].
|
||||
"""
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = None
|
||||
self.schedule = None # sigma(t_i)
|
||||
|
||||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def set_timesteps(self, num_inference_steps):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
|
||||
self.schedule = [
|
||||
(self.sigma_max * (self.sigma_min**2 / self.sigma_max**2) ** (i / (num_inference_steps - 1)))
|
||||
for i in self.timesteps
|
||||
]
|
||||
self.schedule = np.array(self.schedule, dtype=np.float32)
|
||||
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
def add_noise_to_input(self, sample, sigma, generator=None):
|
||||
"""
|
||||
Explicit Langevin-like "churn" step of adding noise to the sample according to
|
||||
a factor gamma_i ≥ 0 to reach a higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
|
||||
"""
|
||||
if self.s_min <= sigma <= self.s_max:
|
||||
gamma = min(self.s_churn / self.num_inference_steps, 2**0.5 - 1)
|
||||
else:
|
||||
gamma = 0
|
||||
|
||||
# sample eps ~ N(0, S_noise^2 * I)
|
||||
eps = self.s_noise * torch.randn(sample.shape, generator=generator).to(sample.device)
|
||||
sigma_hat = sigma + gamma * sigma
|
||||
sample_hat = sample + ((sigma_hat**2 - sigma**2) ** 0.5 * eps)
|
||||
|
||||
return sample_hat, sigma_hat
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
sigma_hat: float,
|
||||
sigma_prev: float,
|
||||
sample_hat: Union[torch.FloatTensor, np.ndarray],
|
||||
):
|
||||
pred_original_sample = sample_hat + sigma_hat * model_output
|
||||
derivative = (sample_hat - pred_original_sample) / sigma_hat
|
||||
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
|
||||
|
||||
return {"prev_sample": sample_prev, "derivative": derivative}
|
||||
|
||||
def step_correct(
|
||||
self,
|
||||
model_output: Union[torch.FloatTensor, np.ndarray],
|
||||
sigma_hat: float,
|
||||
sigma_prev: float,
|
||||
sample_hat: Union[torch.FloatTensor, np.ndarray],
|
||||
sample_prev: Union[torch.FloatTensor, np.ndarray],
|
||||
derivative: Union[torch.FloatTensor, np.ndarray],
|
||||
):
|
||||
pred_original_sample = sample_prev + sigma_prev * model_output
|
||||
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
|
||||
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
|
||||
return {"prev_sample": sample_prev, "derivative": derivative_corr}
|
||||
|
||||
def add_noise(self, original_samples, noise, timesteps):
|
||||
raise NotImplementedError()
|
|
@ -29,6 +29,8 @@ from diffusers import (
|
|||
DDIMScheduler,
|
||||
DDPMPipeline,
|
||||
DDPMScheduler,
|
||||
KarrasVePipeline,
|
||||
KarrasVeScheduler,
|
||||
LDMPipeline,
|
||||
LDMTextToImagePipeline,
|
||||
PNDMPipeline,
|
||||
|
@ -909,3 +911,19 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
# the values aren't exactly equal, but the images look the same visually
|
||||
assert np.abs(ddpm_images - ddim_images).max() < 1e-1
|
||||
|
||||
@slow
|
||||
def test_karras_ve_pipeline(self):
|
||||
model_id = "google/ncsnpp-celebahq-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
scheduler = KarrasVeScheduler(tensor_format="pt")
|
||||
|
||||
pipe = KarrasVePipeline(unet=model, scheduler=scheduler)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = np.array([0.26815, 0.1581, 0.2658, 0.23248, 0.1550, 0.2539, 0.1131, 0.1024, 0.0837])
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
|
Loading…
Reference in New Issue