Improve dynamic thresholding and extend to DDPM and DDIM Schedulers (#2528)

* Improve dynamic threshold

* Update code

* Add dynamic threshold to ddim and ddpm

* Encapsulate and leverage code copy mechanism

Update style

* Clean up DDPM/DDIM constructor arguments

* add test

* also add to unipc

---------

Co-authored-by: Peter Lin <peterlin9863@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
clarencechen 2023-03-07 14:10:26 -08:00 committed by GitHub
parent 46bef6e31d
commit 55660cfb6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 171 additions and 60 deletions

View File

@ -70,8 +70,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
class_embed_type (`str`, *optional*, defaults to None):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.

View File

@ -90,8 +90,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
class_embed_type (`str`, *optional*, defaults to None):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.

View File

@ -171,8 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
class_embed_type (`str`, *optional*, defaults to None):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.

View File

@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@ -128,6 +139,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@ -184,6 +199,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@ -286,9 +313,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
" `v_prediction`"
)
# 4. Clip "predicted x_0"
# 4. Clip or threshold "predicted x_0"
if self.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 α_t1)/(1 α_t)) * sqrt(1 α_t/α_t1)

View File

@ -98,11 +98,22 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability.
option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
"""
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
@ -119,7 +130,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
clip_sample_range: Optional[float] = 1.0,
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
):
if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@ -226,6 +240,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return variance
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def step(
self,
model_output: torch.FloatTensor,
@ -283,12 +308,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
" `v_prediction` for the DDPMScheduler."
)
# 3. Clip "predicted x_0"
# 3. Clip or threshold "predicted x_0"
if self.config.clip_sample:
pred_original_sample = torch.clamp(
pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t

View File

@ -96,7 +96,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487).
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid woks when `thresholding=True`
the threshold value for dynamic thresholding. Valid only when `thresholding=True`
algorithm_type (`str`, default `deis`):
the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in
the future
@ -194,6 +194,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order
self.lower_order_nums = 0
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
@ -228,15 +240,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile(
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
if self.config.algorithm_type == "deis":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]

View File

@ -204,6 +204,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order
self.lower_order_nums = 0
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
@ -247,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile(
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":

View File

@ -237,6 +237,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sample = None
self.orders = self.get_order_list(num_inference_steps)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
@ -277,18 +289,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dtype = x0_pred.dtype
dynamic_max_val = torch.quantile(
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)).float(),
self.config.dynamic_thresholding_ratio,
dim=1,
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.to(dtype)
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver":

View File

@ -109,7 +109,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
"""
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
@ -129,8 +130,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional):
final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
"""
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min

View File

@ -116,7 +116,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
"""
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
@ -143,8 +144,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional):
final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
"""
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max

View File

@ -210,6 +210,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.solver_p:
self.solver_p.set_timesteps(num_inference_steps, device=device)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor:
@ -245,15 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile(
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
return x0_pred
else:
if self.config.prediction_type == "epsilon":

View File

@ -647,6 +647,16 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_thresholding(self):
self.check_over_configs(thresholding=False)
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
)
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type)
@ -791,6 +801,16 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_thresholding(self):
self.check_over_configs(thresholding=False)
for threshold in [0.5, 1.0, 2.0]:
for prediction_type in ["epsilon", "v_prediction"]:
self.check_over_configs(
thresholding=True,
prediction_type=prediction_type,
sample_max_value=threshold,
)
def test_time_indices(self):
for t in [1, 10, 49]:
self.check_over_forward(time_step=t)
@ -1212,6 +1232,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.3301) < 1e-3
def test_full_loop_no_noise_thres(self):
sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 0.6405) < 1e-3
def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample))