Improve dynamic thresholding and extend to DDPM and DDIM Schedulers (#2528)

* Improve dynamic threshold

* Update code

* Add dynamic threshold to ddim and ddpm

* Encapsulate and leverage code copy mechanism

Update style

* Clean up DDPM/DDIM constructor arguments

* add test

* also add to unipc

---------

Co-authored-by: Peter Lin <peterlin9863@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
clarencechen 2023-03-07 14:10:26 -08:00 committed by GitHub
parent 46bef6e31d
commit 55660cfb6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 171 additions and 60 deletions

View File

@ -70,8 +70,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization. norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately class_embed_type (`str`, *optional*, defaults to None):
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`. The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, or `"identity"`.
num_class_embeds (`int`, *optional*, defaults to None): num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.

View File

@ -90,8 +90,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately class_embed_type (`str`, *optional*, defaults to None):
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
num_class_embeds (`int`, *optional*, defaults to None): num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.

View File

@ -171,8 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`. for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately class_embed_type (`str`, *optional*, defaults to None):
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, or `"projection"`.
num_class_embeds (`int`, *optional*, defaults to None): num_class_embeds (`int`, *optional*, defaults to None):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`. class conditioning with `class_embed_type` equal to `None`.

View File

@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
trained_betas (`np.ndarray`, optional): trained_betas (`np.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
clip_sample (`bool`, default `True`): clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability. option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
set_alpha_to_one (`bool`, default `True`): set_alpha_to_one (`bool`, default `True`):
each diffusion step uses the value of alphas product at that step and at the previous one. For the final each diffusion step uses the value of alphas product at that step and at the previous one. For the final
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf) https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@ -128,6 +139,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
set_alpha_to_one: bool = True, set_alpha_to_one: bool = True,
steps_offset: int = 0, steps_offset: int = 0,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@ -184,6 +199,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
""" """
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@ -286,9 +313,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
" `v_prediction`" " `v_prediction`"
) )
# 4. Clip "predicted x_0" # 4. Clip or threshold "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
pred_original_sample = torch.clamp(pred_original_sample, -1, 1) pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
# 5. compute variance: "sigma_t(η)" -> see formula (16) # 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 - α_{t-1})/(1 - α_t)) * sqrt(1 - α_t/α_{t-1}) # σ_t = sqrt((1 - α_{t-1})/(1 - α_t)) * sqrt(1 - α_t/α_{t-1})

View File

@ -98,11 +98,22 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`, options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`. `fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
clip_sample (`bool`, default `True`): clip_sample (`bool`, default `True`):
option to clip predicted sample between -1 and 1 for numerical stability. option to clip predicted sample for numerical stability.
clip_sample_range (`float`, default `1.0`):
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
prediction_type (`str`, default `epsilon`, optional): prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4 process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf) https://imagen.research.google/video/paper.pdf)
thresholding (`bool`, default `False`):
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
stable-diffusion).
dynamic_thresholding_ratio (`float`, default `0.995`):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
""" """
_compatibles = [e.name for e in KarrasDiffusionSchedulers] _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@ -119,7 +130,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
variance_type: str = "fixed_small", variance_type: str = "fixed_small",
clip_sample: bool = True, clip_sample: bool = True,
prediction_type: str = "epsilon", prediction_type: str = "epsilon",
clip_sample_range: Optional[float] = 1.0, thresholding: bool = False,
dynamic_thresholding_ratio: float = 0.995,
clip_sample_range: float = 1.0,
sample_max_value: float = 1.0,
): ):
if trained_betas is not None: if trained_betas is not None:
self.betas = torch.tensor(trained_betas, dtype=torch.float32) self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@ -226,6 +240,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
return variance return variance
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def step( def step(
self, self,
model_output: torch.FloatTensor, model_output: torch.FloatTensor,
@ -283,12 +308,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
" `v_prediction` for the DDPMScheduler." " `v_prediction` for the DDPMScheduler."
) )
# 3. Clip "predicted x_0" # 3. Clip or threshold "predicted x_0"
if self.config.clip_sample: if self.config.clip_sample:
pred_original_sample = torch.clamp( pred_original_sample = pred_original_sample.clamp(
pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range -self.config.clip_sample_range, self.config.clip_sample_range
) )
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t

View File

@ -96,7 +96,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
(https://arxiv.org/abs/2205.11487). (https://arxiv.org/abs/2205.11487).
sample_max_value (`float`, default `1.0`): sample_max_value (`float`, default `1.0`):
the threshold value for dynamic thresholding. Valid woks when `thresholding=True` the threshold value for dynamic thresholding. Valid only when `thresholding=True`
algorithm_type (`str`, default `deis`): algorithm_type (`str`, default `deis`):
the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in
the future the future
@ -194,6 +194,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order ] * self.config.solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
@ -228,15 +240,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]: if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float() x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile( x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
if self.config.algorithm_type == "deis": if self.config.algorithm_type == "deis":
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]

View File

@ -204,6 +204,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
] * self.config.solver_order ] * self.config.solver_order
self.lower_order_nums = 0 self.lower_order_nums = 0
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
@ -247,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]: if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float() x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile( x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
return x0_pred return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model. # DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":

View File

@ -237,6 +237,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
self.sample = None self.sample = None
self.orders = self.get_order_list(num_inference_steps) self.orders = self.get_order_list(num_inference_steps)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
@ -277,18 +289,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
if self.config.thresholding: if self.config.thresholding:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487 # Dynamic thresholding in https://arxiv.org/abs/2205.11487
dtype = x0_pred.dtype orig_dtype = x0_pred.dtype
dynamic_max_val = torch.quantile( if orig_dtype not in [torch.float, torch.double]:
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)).float(), x0_pred = x0_pred.float()
self.config.dynamic_thresholding_ratio, x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
dim=1,
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.to(dtype)
return x0_pred return x0_pred
# DPM-Solver needs to solve an integral of the noise prediction model. # DPM-Solver needs to solve an integral of the noise prediction model.
elif self.config.algorithm_type == "dpmsolver": elif self.config.algorithm_type == "dpmsolver":

View File

@ -109,7 +109,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
Args: Args:
num_inference_steps (`int`): num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
""" """
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
@ -129,8 +130,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional): sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation). initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). sigma_max (`float`, optional):
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
""" """
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min

View File

@ -116,7 +116,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance. state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
num_inference_steps (`int`): num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
""" """
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
@ -143,8 +144,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model. the number of diffusion steps used when generating samples with a pre-trained model.
sigma_min (`float`, optional): sigma_min (`float`, optional):
initial noise scale value (overrides value given at Scheduler instantiation). initial noise scale value (overrides value given at Scheduler instantiation).
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation). sigma_max (`float`, optional):
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation). final noise scale value (overrides value given at Scheduler instantiation).
sampling_eps (`float`, optional):
final timestep value (overrides value given at Scheduler instantiation).
""" """
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max

View File

@ -210,6 +210,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.solver_p: if self.solver_p:
self.solver_p.set_timesteps(num_inference_steps, device=device) self.solver_p.set_timesteps(num_inference_steps, device=device)
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
dynamic_max_val = (
sample.flatten(1)
.abs()
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
.clamp_min(self.config.sample_max_value)
.view(-1, *([1] * (sample.ndim - 1)))
)
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
def convert_model_output( def convert_model_output(
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
) -> torch.FloatTensor: ) -> torch.FloatTensor:
@ -245,15 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
orig_dtype = x0_pred.dtype orig_dtype = x0_pred.dtype
if orig_dtype not in [torch.float, torch.double]: if orig_dtype not in [torch.float, torch.double]:
x0_pred = x0_pred.float() x0_pred = x0_pred.float()
dynamic_max_val = torch.quantile( x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
)
dynamic_max_val = torch.maximum(
dynamic_max_val,
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
)[(...,) + (None,) * (x0_pred.ndim - 1)]
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
x0_pred = x0_pred.type(orig_dtype)
return x0_pred return x0_pred
else: else:
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":

View File

@ -647,6 +647,16 @@ class DDPMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]: for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
def test_thresholding(self):
    """Run `check_over_configs` with thresholding off, then on for every threshold / prediction type pair."""
    self.check_over_configs(thresholding=False)
    # Threshold is the outer axis and prediction type the inner one, matching a nested loop.
    combos = [
        (max_value, pred_type)
        for max_value in [0.5, 1.0, 2.0]
        for pred_type in ["epsilon", "sample", "v_prediction"]
    ]
    for max_value, pred_type in combos:
        self.check_over_configs(thresholding=True, prediction_type=pred_type, sample_max_value=max_value)
def test_prediction_type(self): def test_prediction_type(self):
for prediction_type in ["epsilon", "sample", "v_prediction"]: for prediction_type in ["epsilon", "sample", "v_prediction"]:
self.check_over_configs(prediction_type=prediction_type) self.check_over_configs(prediction_type=prediction_type)
@ -791,6 +801,16 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for clip_sample in [True, False]: for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample) self.check_over_configs(clip_sample=clip_sample)
def test_thresholding(self):
    """Run `check_over_configs` with thresholding off, then on for every threshold / prediction type pair."""
    self.check_over_configs(thresholding=False)
    # Threshold is the outer axis and prediction type the inner one, matching a nested loop.
    combos = [
        (max_value, pred_type)
        for max_value in [0.5, 1.0, 2.0]
        for pred_type in ["epsilon", "v_prediction"]
    ]
    for max_value, pred_type in combos:
        self.check_over_configs(thresholding=True, prediction_type=pred_type, sample_max_value=max_value)
def test_time_indices(self): def test_time_indices(self):
for t in [1, 10, 49]: for t in [1, 10, 49]:
self.check_over_forward(time_step=t) self.check_over_forward(time_step=t)
@ -1212,6 +1232,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
assert abs(result_mean.item() - 0.3301) < 1e-3 assert abs(result_mean.item() - 0.3301) < 1e-3
def test_full_loop_no_noise_thres(self):
    """Run `full_loop` with thresholding enabled and compare the mean absolute output to a fixed reference."""
    loop_kwargs = {"thresholding": True, "dynamic_thresholding_ratio": 0.87, "sample_max_value": 0.5}
    output = self.full_loop(**loop_kwargs)
    mean_abs = output.abs().mean().item()

    assert abs(mean_abs - 0.6405) < 1e-3
def test_full_loop_with_v_prediction(self): def test_full_loop_with_v_prediction(self):
sample = self.full_loop(prediction_type="v_prediction") sample = self.full_loop(prediction_type="v_prediction")
result_mean = torch.mean(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample))