Improve dynamic thresholding and extend to DDPM and DDIM Schedulers (#2528)
* Improve dynamic threshold
* Update code
* Add dynamic threshold to ddim and ddpm
* Encapsulate and leverage code copy mechanism
Update style
* Clean up DDPM/DDIM constructor arguments
* add test
* also add to unipc
---------
Co-authored-by: Peter Lin <peterlin9863@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
46bef6e31d
commit
55660cfb6d
|
@ -70,8 +70,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
|
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
|
||||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
class_embed_type (`str`, *optional*, defaults to None):
|
||||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||||
|
`"timestep"`, or `"identity"`.
|
||||||
num_class_embeds (`int`, *optional*, defaults to None):
|
num_class_embeds (`int`, *optional*, defaults to None):
|
||||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
class conditioning with `class_embed_type` equal to `None`.
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
|
|
@ -90,8 +90,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
||||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
class_embed_type (`str`, *optional*, defaults to None):
|
||||||
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||||
|
`"timestep"`, `"identity"`, or `"projection"`.
|
||||||
num_class_embeds (`int`, *optional*, defaults to None):
|
num_class_embeds (`int`, *optional*, defaults to None):
|
||||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
class conditioning with `class_embed_type` equal to `None`.
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
|
|
@ -171,8 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||||
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
|
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
|
||||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
class_embed_type (`str`, *optional*, defaults to None):
|
||||||
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||||
|
`"timestep"`, `"identity"`, or `"projection"`.
|
||||||
num_class_embeds (`int`, *optional*, defaults to None):
|
num_class_embeds (`int`, *optional*, defaults to None):
|
||||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
class conditioning with `class_embed_type` equal to `None`.
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
|
|
@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
trained_betas (`np.ndarray`, optional):
|
trained_betas (`np.ndarray`, optional):
|
||||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||||
clip_sample (`bool`, default `True`):
|
clip_sample (`bool`, default `True`):
|
||||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
option to clip predicted sample for numerical stability.
|
||||||
|
clip_sample_range (`float`, default `1.0`):
|
||||||
|
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||||
set_alpha_to_one (`bool`, default `True`):
|
set_alpha_to_one (`bool`, default `True`):
|
||||||
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
||||||
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||||
|
@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||||
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
|
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
|
||||||
https://imagen.research.google/video/paper.pdf)
|
https://imagen.research.google/video/paper.pdf)
|
||||||
|
thresholding (`bool`, default `False`):
|
||||||
|
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||||
|
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||||
|
stable-diffusion).
|
||||||
|
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||||
|
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||||
|
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
|
||||||
|
sample_max_value (`float`, default `1.0`):
|
||||||
|
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||||
|
@ -128,6 +139,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
set_alpha_to_one: bool = True,
|
set_alpha_to_one: bool = True,
|
||||||
steps_offset: int = 0,
|
steps_offset: int = 0,
|
||||||
prediction_type: str = "epsilon",
|
prediction_type: str = "epsilon",
|
||||||
|
thresholding: bool = False,
|
||||||
|
dynamic_thresholding_ratio: float = 0.995,
|
||||||
|
clip_sample_range: float = 1.0,
|
||||||
|
sample_max_value: float = 1.0,
|
||||||
):
|
):
|
||||||
if trained_betas is not None:
|
if trained_betas is not None:
|
||||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||||
|
@ -184,6 +199,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
return variance
|
return variance
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||||
|
dynamic_max_val = (
|
||||||
|
sample.flatten(1)
|
||||||
|
.abs()
|
||||||
|
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||||
|
.clamp_min(self.config.sample_max_value)
|
||||||
|
.view(-1, *([1] * (sample.ndim - 1)))
|
||||||
|
)
|
||||||
|
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||||
"""
|
"""
|
||||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||||
|
@ -286,9 +313,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
" `v_prediction`"
|
" `v_prediction`"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Clip "predicted x_0"
|
# 4. Clip or threshold "predicted x_0"
|
||||||
if self.config.clip_sample:
|
if self.config.clip_sample:
|
||||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
pred_original_sample = pred_original_sample.clamp(
|
||||||
|
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.config.thresholding:
|
||||||
|
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||||
|
|
||||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||||
|
|
|
@ -98,11 +98,22 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||||
clip_sample (`bool`, default `True`):
|
clip_sample (`bool`, default `True`):
|
||||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
option to clip predicted sample for numerical stability.
|
||||||
|
clip_sample_range (`float`, default `1.0`):
|
||||||
|
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||||
prediction_type (`str`, default `epsilon`, optional):
|
prediction_type (`str`, default `epsilon`, optional):
|
||||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||||
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
|
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
|
||||||
https://imagen.research.google/video/paper.pdf)
|
https://imagen.research.google/video/paper.pdf)
|
||||||
|
thresholding (`bool`, default `False`):
|
||||||
|
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||||
|
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||||
|
stable-diffusion).
|
||||||
|
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||||
|
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||||
|
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
|
||||||
|
sample_max_value (`float`, default `1.0`):
|
||||||
|
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||||
|
@ -119,7 +130,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
variance_type: str = "fixed_small",
|
variance_type: str = "fixed_small",
|
||||||
clip_sample: bool = True,
|
clip_sample: bool = True,
|
||||||
prediction_type: str = "epsilon",
|
prediction_type: str = "epsilon",
|
||||||
clip_sample_range: Optional[float] = 1.0,
|
thresholding: bool = False,
|
||||||
|
dynamic_thresholding_ratio: float = 0.995,
|
||||||
|
clip_sample_range: float = 1.0,
|
||||||
|
sample_max_value: float = 1.0,
|
||||||
):
|
):
|
||||||
if trained_betas is not None:
|
if trained_betas is not None:
|
||||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||||
|
@ -226,6 +240,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
return variance
|
return variance
|
||||||
|
|
||||||
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||||
|
dynamic_max_val = (
|
||||||
|
sample.flatten(1)
|
||||||
|
.abs()
|
||||||
|
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||||
|
.clamp_min(self.config.sample_max_value)
|
||||||
|
.view(-1, *([1] * (sample.ndim - 1)))
|
||||||
|
)
|
||||||
|
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||||
|
|
||||||
def step(
|
def step(
|
||||||
self,
|
self,
|
||||||
model_output: torch.FloatTensor,
|
model_output: torch.FloatTensor,
|
||||||
|
@ -283,12 +308,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
" `v_prediction` for the DDPMScheduler."
|
" `v_prediction` for the DDPMScheduler."
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. Clip "predicted x_0"
|
# 3. Clip or threshold "predicted x_0"
|
||||||
if self.config.clip_sample:
|
if self.config.clip_sample:
|
||||||
pred_original_sample = torch.clamp(
|
pred_original_sample = pred_original_sample.clamp(
|
||||||
pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
|
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.config.thresholding:
|
||||||
|
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||||
|
|
||||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||||
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
|
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
|
||||||
|
|
|
@ -96,7 +96,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||||
(https://arxiv.org/abs/2205.11487).
|
(https://arxiv.org/abs/2205.11487).
|
||||||
sample_max_value (`float`, default `1.0`):
|
sample_max_value (`float`, default `1.0`):
|
||||||
the threshold value for dynamic thresholding. Valid woks when `thresholding=True`
|
the threshold value for dynamic thresholding. Valid only when `thresholding=True`
|
||||||
algorithm_type (`str`, default `deis`):
|
algorithm_type (`str`, default `deis`):
|
||||||
the algorithm type for the solver. Currently we support multistep deis; we will add other variants of DEIS in
|
the algorithm type for the solver. Currently we support multistep deis; we will add other variants of DEIS in
|
||||||
the future
|
the future
|
||||||
|
@ -194,6 +194,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
] * self.config.solver_order
|
] * self.config.solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||||
|
dynamic_max_val = (
|
||||||
|
sample.flatten(1)
|
||||||
|
.abs()
|
||||||
|
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||||
|
.clamp_min(self.config.sample_max_value)
|
||||||
|
.view(-1, *([1] * (sample.ndim - 1)))
|
||||||
|
)
|
||||||
|
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||||
|
|
||||||
def convert_model_output(
|
def convert_model_output(
|
||||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
@ -228,15 +240,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
orig_dtype = x0_pred.dtype
|
orig_dtype = x0_pred.dtype
|
||||||
if orig_dtype not in [torch.float, torch.double]:
|
if orig_dtype not in [torch.float, torch.double]:
|
||||||
x0_pred = x0_pred.float()
|
x0_pred = x0_pred.float()
|
||||||
dynamic_max_val = torch.quantile(
|
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
|
||||||
)
|
|
||||||
dynamic_max_val = torch.maximum(
|
|
||||||
dynamic_max_val,
|
|
||||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
|
||||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
|
||||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
|
||||||
x0_pred = x0_pred.type(orig_dtype)
|
|
||||||
|
|
||||||
if self.config.algorithm_type == "deis":
|
if self.config.algorithm_type == "deis":
|
||||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||||
|
|
|
@ -204,6 +204,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
] * self.config.solver_order
|
] * self.config.solver_order
|
||||||
self.lower_order_nums = 0
|
self.lower_order_nums = 0
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||||
|
dynamic_max_val = (
|
||||||
|
sample.flatten(1)
|
||||||
|
.abs()
|
||||||
|
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||||
|
.clamp_min(self.config.sample_max_value)
|
||||||
|
.view(-1, *([1] * (sample.ndim - 1)))
|
||||||
|
)
|
||||||
|
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||||
|
|
||||||
def convert_model_output(
|
def convert_model_output(
|
||||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
@ -247,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
orig_dtype = x0_pred.dtype
|
orig_dtype = x0_pred.dtype
|
||||||
if orig_dtype not in [torch.float, torch.double]:
|
if orig_dtype not in [torch.float, torch.double]:
|
||||||
x0_pred = x0_pred.float()
|
x0_pred = x0_pred.float()
|
||||||
dynamic_max_val = torch.quantile(
|
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
|
||||||
)
|
|
||||||
dynamic_max_val = torch.maximum(
|
|
||||||
dynamic_max_val,
|
|
||||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
|
||||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
|
||||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
|
||||||
x0_pred = x0_pred.type(orig_dtype)
|
|
||||||
return x0_pred
|
return x0_pred
|
||||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||||
elif self.config.algorithm_type == "dpmsolver":
|
elif self.config.algorithm_type == "dpmsolver":
|
||||||
|
|
|
@ -237,6 +237,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
self.sample = None
|
self.sample = None
|
||||||
self.orders = self.get_order_list(num_inference_steps)
|
self.orders = self.get_order_list(num_inference_steps)
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||||
|
dynamic_max_val = (
|
||||||
|
sample.flatten(1)
|
||||||
|
.abs()
|
||||||
|
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||||
|
.clamp_min(self.config.sample_max_value)
|
||||||
|
.view(-1, *([1] * (sample.ndim - 1)))
|
||||||
|
)
|
||||||
|
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||||
|
|
||||||
def convert_model_output(
|
def convert_model_output(
|
||||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
@ -277,18 +289,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
if self.config.thresholding:
|
if self.config.thresholding:
|
||||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||||
dtype = x0_pred.dtype
|
orig_dtype = x0_pred.dtype
|
||||||
dynamic_max_val = torch.quantile(
|
if orig_dtype not in [torch.float, torch.double]:
|
||||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)).float(),
|
x0_pred = x0_pred.float()
|
||||||
self.config.dynamic_thresholding_ratio,
|
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||||
dim=1,
|
|
||||||
)
|
|
||||||
dynamic_max_val = torch.maximum(
|
|
||||||
dynamic_max_val,
|
|
||||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
|
||||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
|
||||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
|
||||||
x0_pred = x0_pred.to(dtype)
|
|
||||||
return x0_pred
|
return x0_pred
|
||||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||||
elif self.config.algorithm_type == "dpmsolver":
|
elif self.config.algorithm_type == "dpmsolver":
|
||||||
|
|
|
@ -109,7 +109,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
Args:
|
Args:
|
||||||
num_inference_steps (`int`):
|
num_inference_steps (`int`):
|
||||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
sampling_eps (`float`, optional):
|
||||||
|
final timestep value (overrides value given at Scheduler instantiation).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||||
|
@ -129,8 +130,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
sigma_min (`float`, optional):
|
sigma_min (`float`, optional):
|
||||||
initial noise scale value (overrides value given at Scheduler instantiation).
|
initial noise scale value (overrides value given at Scheduler instantiation).
|
||||||
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
|
sigma_max (`float`, optional):
|
||||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
final noise scale value (overrides value given at Scheduler instantiation).
|
||||||
|
sampling_eps (`float`, optional):
|
||||||
|
final timestep value (overrides value given at Scheduler instantiation).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
||||||
|
|
|
@ -116,7 +116,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
|
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
|
||||||
num_inference_steps (`int`):
|
num_inference_steps (`int`):
|
||||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
sampling_eps (`float`, optional):
|
||||||
|
final timestep value (overrides value given at Scheduler instantiation).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||||
|
@ -143,8 +144,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||||
sigma_min (`float`, optional):
|
sigma_min (`float`, optional):
|
||||||
initial noise scale value (overrides value given at Scheduler instantiation).
|
initial noise scale value (overrides value given at Scheduler instantiation).
|
||||||
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
|
sigma_max (`float`, optional):
|
||||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
final noise scale value (overrides value given at Scheduler instantiation).
|
||||||
|
sampling_eps (`float`, optional):
|
||||||
|
final timestep value (overrides value given at Scheduler instantiation).
|
||||||
"""
|
"""
|
||||||
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
||||||
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
|
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
|
||||||
|
|
|
@ -210,6 +210,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
if self.solver_p:
|
if self.solver_p:
|
||||||
self.solver_p.set_timesteps(num_inference_steps, device=device)
|
self.solver_p.set_timesteps(num_inference_steps, device=device)
|
||||||
|
|
||||||
|
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||||
|
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||||
|
dynamic_max_val = (
|
||||||
|
sample.flatten(1)
|
||||||
|
.abs()
|
||||||
|
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||||
|
.clamp_min(self.config.sample_max_value)
|
||||||
|
.view(-1, *([1] * (sample.ndim - 1)))
|
||||||
|
)
|
||||||
|
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||||
|
|
||||||
def convert_model_output(
|
def convert_model_output(
|
||||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
|
@ -245,15 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||||
orig_dtype = x0_pred.dtype
|
orig_dtype = x0_pred.dtype
|
||||||
if orig_dtype not in [torch.float, torch.double]:
|
if orig_dtype not in [torch.float, torch.double]:
|
||||||
x0_pred = x0_pred.float()
|
x0_pred = x0_pred.float()
|
||||||
dynamic_max_val = torch.quantile(
|
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
|
||||||
)
|
|
||||||
dynamic_max_val = torch.maximum(
|
|
||||||
dynamic_max_val,
|
|
||||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
|
||||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
|
||||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
|
||||||
x0_pred = x0_pred.type(orig_dtype)
|
|
||||||
return x0_pred
|
return x0_pred
|
||||||
else:
|
else:
|
||||||
if self.config.prediction_type == "epsilon":
|
if self.config.prediction_type == "epsilon":
|
||||||
|
|
|
@ -647,6 +647,16 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
||||||
for clip_sample in [True, False]:
|
for clip_sample in [True, False]:
|
||||||
self.check_over_configs(clip_sample=clip_sample)
|
self.check_over_configs(clip_sample=clip_sample)
|
||||||
|
|
||||||
|
def test_thresholding(self):
|
||||||
|
self.check_over_configs(thresholding=False)
|
||||||
|
for threshold in [0.5, 1.0, 2.0]:
|
||||||
|
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||||
|
self.check_over_configs(
|
||||||
|
thresholding=True,
|
||||||
|
prediction_type=prediction_type,
|
||||||
|
sample_max_value=threshold,
|
||||||
|
)
|
||||||
|
|
||||||
def test_prediction_type(self):
|
def test_prediction_type(self):
|
||||||
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||||
self.check_over_configs(prediction_type=prediction_type)
|
self.check_over_configs(prediction_type=prediction_type)
|
||||||
|
@ -791,6 +801,16 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
||||||
for clip_sample in [True, False]:
|
for clip_sample in [True, False]:
|
||||||
self.check_over_configs(clip_sample=clip_sample)
|
self.check_over_configs(clip_sample=clip_sample)
|
||||||
|
|
||||||
|
def test_thresholding(self):
|
||||||
|
self.check_over_configs(thresholding=False)
|
||||||
|
for threshold in [0.5, 1.0, 2.0]:
|
||||||
|
for prediction_type in ["epsilon", "v_prediction"]:
|
||||||
|
self.check_over_configs(
|
||||||
|
thresholding=True,
|
||||||
|
prediction_type=prediction_type,
|
||||||
|
sample_max_value=threshold,
|
||||||
|
)
|
||||||
|
|
||||||
def test_time_indices(self):
|
def test_time_indices(self):
|
||||||
for t in [1, 10, 49]:
|
for t in [1, 10, 49]:
|
||||||
self.check_over_forward(time_step=t)
|
self.check_over_forward(time_step=t)
|
||||||
|
@ -1212,6 +1232,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
||||||
|
|
||||||
assert abs(result_mean.item() - 0.3301) < 1e-3
|
assert abs(result_mean.item() - 0.3301) < 1e-3
|
||||||
|
|
||||||
|
def test_full_loop_no_noise_thres(self):
|
||||||
|
sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
|
||||||
|
result_mean = torch.mean(torch.abs(sample))
|
||||||
|
|
||||||
|
assert abs(result_mean.item() - 0.6405) < 1e-3
|
||||||
|
|
||||||
def test_full_loop_with_v_prediction(self):
|
def test_full_loop_with_v_prediction(self):
|
||||||
sample = self.full_loop(prediction_type="v_prediction")
|
sample = self.full_loop(prediction_type="v_prediction")
|
||||||
result_mean = torch.mean(torch.abs(sample))
|
result_mean = torch.mean(torch.abs(sample))
|
||||||
|
|
Loading…
Reference in New Issue