Improve dynamic thresholding and extend to DDPM and DDIM Schedulers (#2528)
* Improve dynamic threshold * Update code * Add dynamic threshold to ddim and ddpm * Encapsulate and leverage code copy mechanism Update style * Clean up DDPM/DDIM constructor arguments * add test * also add to unipc --------- Co-authored-by: Peter Lin <peterlin9863@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
46bef6e31d
commit
55660cfb6d
|
@ -70,8 +70,9 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
|||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, or `"identity"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
|
|
|
@ -90,8 +90,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
|
|||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, `"identity"`, or `"projection"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
|
|
|
@ -171,8 +171,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
|||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
|
||||
class_embed_type (`str`, *optional*, defaults to None):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
||||
`"timestep"`, `"identity"`, or `"projection"`.
|
||||
num_class_embeds (`int`, *optional*, defaults to None):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
|
|
|
@ -98,7 +98,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
trained_betas (`np.ndarray`, optional):
|
||||
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
option to clip predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
set_alpha_to_one (`bool`, default `True`):
|
||||
each diffusion step uses the value of alphas product at that step and at the previous one. For the final
|
||||
step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
|
||||
|
@ -111,6 +113,15 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf)
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||
stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
|
@ -128,6 +139,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
|
@ -184,6 +199,18 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return variance
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
@ -286,9 +313,14 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
" `v_prediction`"
|
||||
)
|
||||
|
||||
# 4. Clip "predicted x_0"
|
||||
# 4. Clip or threshold "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
|
|
|
@ -98,11 +98,22 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small`,
|
||||
`fixed_small_log`, `fixed_large`, `fixed_large_log`, `learned` or `learned_range`.
|
||||
clip_sample (`bool`, default `True`):
|
||||
option to clip predicted sample between -1 and 1 for numerical stability.
|
||||
option to clip predicted sample for numerical stability.
|
||||
clip_sample_range (`float`, default `1.0`):
|
||||
the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
|
||||
prediction_type (`str`, default `epsilon`, optional):
|
||||
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
|
||||
process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4
|
||||
https://imagen.research.google/video/paper.pdf)
|
||||
thresholding (`bool`, default `False`):
|
||||
whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
|
||||
Note that the thresholding method is unsuitable for latent-space diffusion models (such as
|
||||
stable-diffusion).
|
||||
dynamic_thresholding_ratio (`float`, default `0.995`):
|
||||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
|
||||
"""
|
||||
|
||||
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
|
||||
|
@ -119,7 +130,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
variance_type: str = "fixed_small",
|
||||
clip_sample: bool = True,
|
||||
prediction_type: str = "epsilon",
|
||||
clip_sample_range: Optional[float] = 1.0,
|
||||
thresholding: bool = False,
|
||||
dynamic_thresholding_ratio: float = 0.995,
|
||||
clip_sample_range: float = 1.0,
|
||||
sample_max_value: float = 1.0,
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.tensor(trained_betas, dtype=torch.float32)
|
||||
|
@ -226,6 +240,17 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return variance
|
||||
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
|
@ -283,12 +308,15 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
" `v_prediction` for the DDPMScheduler."
|
||||
)
|
||||
|
||||
# 3. Clip "predicted x_0"
|
||||
# 3. Clip or threshold "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
pred_original_sample = torch.clamp(
|
||||
pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
|
||||
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
|
||||
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
|
||||
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
|
||||
|
|
|
@ -96,7 +96,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
|
||||
(https://arxiv.org/abs/2205.11487).
|
||||
sample_max_value (`float`, default `1.0`):
|
||||
the threshold value for dynamic thresholding. Valid woks when `thresholding=True`
|
||||
the threshold value for dynamic thresholding. Valid only when `thresholding=True`
|
||||
algorithm_type (`str`, default `deis`):
|
||||
the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in
|
||||
the future
|
||||
|
@ -194,6 +194,18 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
|
@ -228,15 +240,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
|
||||
if self.config.algorithm_type == "deis":
|
||||
alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep]
|
||||
|
|
|
@ -204,6 +204,18 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
] * self.config.solver_order
|
||||
self.lower_order_nums = 0
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
|
@ -247,15 +259,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
|
|
|
@ -237,6 +237,18 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||
self.sample = None
|
||||
self.orders = self.get_order_list(num_inference_steps)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
|
@ -277,18 +289,10 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
if self.config.thresholding:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dtype = x0_pred.dtype
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)).float(),
|
||||
self.config.dynamic_thresholding_ratio,
|
||||
dim=1,
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.to(dtype)
|
||||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
return x0_pred
|
||||
# DPM-Solver needs to solve an integral of the noise prediction model.
|
||||
elif self.config.algorithm_type == "dpmsolver":
|
||||
|
|
|
@ -109,7 +109,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional):
|
||||
final timestep value (overrides value given at Scheduler instantiation).
|
||||
|
||||
"""
|
||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||
|
@ -129,8 +130,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
sigma_min (`float`, optional):
|
||||
initial noise scale value (overrides value given at Scheduler instantiation).
|
||||
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
||||
sigma_max (`float`, optional):
|
||||
final noise scale value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional):
|
||||
final timestep value (overrides value given at Scheduler instantiation).
|
||||
|
||||
"""
|
||||
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
||||
|
|
|
@ -116,7 +116,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
state (`ScoreSdeVeSchedulerState`): the `FlaxScoreSdeVeScheduler` state data class instance.
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional):
|
||||
final timestep value (overrides value given at Scheduler instantiation).
|
||||
|
||||
"""
|
||||
sampling_eps = sampling_eps if sampling_eps is not None else self.config.sampling_eps
|
||||
|
@ -143,8 +144,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
|||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
sigma_min (`float`, optional):
|
||||
initial noise scale value (overrides value given at Scheduler instantiation).
|
||||
sigma_max (`float`, optional): final noise scale value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional): final timestep value (overrides value given at Scheduler instantiation).
|
||||
sigma_max (`float`, optional):
|
||||
final noise scale value (overrides value given at Scheduler instantiation).
|
||||
sampling_eps (`float`, optional):
|
||||
final timestep value (overrides value given at Scheduler instantiation).
|
||||
"""
|
||||
sigma_min = sigma_min if sigma_min is not None else self.config.sigma_min
|
||||
sigma_max = sigma_max if sigma_max is not None else self.config.sigma_max
|
||||
|
|
|
@ -210,6 +210,18 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
if self.solver_p:
|
||||
self.solver_p.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
|
||||
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
# Dynamic thresholding in https://arxiv.org/abs/2205.11487
|
||||
dynamic_max_val = (
|
||||
sample.flatten(1)
|
||||
.abs()
|
||||
.quantile(self.config.dynamic_thresholding_ratio, dim=1)
|
||||
.clamp_min(self.config.sample_max_value)
|
||||
.view(-1, *([1] * (sample.ndim - 1)))
|
||||
)
|
||||
return sample.clamp(-dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
|
||||
def convert_model_output(
|
||||
self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
|
@ -245,15 +257,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|||
orig_dtype = x0_pred.dtype
|
||||
if orig_dtype not in [torch.float, torch.double]:
|
||||
x0_pred = x0_pred.float()
|
||||
dynamic_max_val = torch.quantile(
|
||||
torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1
|
||||
)
|
||||
dynamic_max_val = torch.maximum(
|
||||
dynamic_max_val,
|
||||
self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device),
|
||||
)[(...,) + (None,) * (x0_pred.ndim - 1)]
|
||||
x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val
|
||||
x0_pred = x0_pred.type(orig_dtype)
|
||||
x0_pred = self._threshold_sample(x0_pred).type(orig_dtype)
|
||||
return x0_pred
|
||||
else:
|
||||
if self.config.prediction_type == "epsilon":
|
||||
|
|
|
@ -647,6 +647,16 @@ class DDPMSchedulerTest(SchedulerCommonTest):
|
|||
for clip_sample in [True, False]:
|
||||
self.check_over_configs(clip_sample=clip_sample)
|
||||
|
||||
def test_thresholding(self):
|
||||
self.check_over_configs(thresholding=False)
|
||||
for threshold in [0.5, 1.0, 2.0]:
|
||||
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||
self.check_over_configs(
|
||||
thresholding=True,
|
||||
prediction_type=prediction_type,
|
||||
sample_max_value=threshold,
|
||||
)
|
||||
|
||||
def test_prediction_type(self):
|
||||
for prediction_type in ["epsilon", "sample", "v_prediction"]:
|
||||
self.check_over_configs(prediction_type=prediction_type)
|
||||
|
@ -791,6 +801,16 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
|||
for clip_sample in [True, False]:
|
||||
self.check_over_configs(clip_sample=clip_sample)
|
||||
|
||||
def test_thresholding(self):
|
||||
self.check_over_configs(thresholding=False)
|
||||
for threshold in [0.5, 1.0, 2.0]:
|
||||
for prediction_type in ["epsilon", "v_prediction"]:
|
||||
self.check_over_configs(
|
||||
thresholding=True,
|
||||
prediction_type=prediction_type,
|
||||
sample_max_value=threshold,
|
||||
)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [1, 10, 49]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
@ -1212,6 +1232,12 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
|
|||
|
||||
assert abs(result_mean.item() - 0.3301) < 1e-3
|
||||
|
||||
def test_full_loop_no_noise_thres(self):
|
||||
sample = self.full_loop(thresholding=True, dynamic_thresholding_ratio=0.87, sample_max_value=0.5)
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 0.6405) < 1e-3
|
||||
|
||||
def test_full_loop_with_v_prediction(self):
|
||||
sample = self.full_loop(prediction_type="v_prediction")
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
|
Loading…
Reference in New Issue