diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index 6258933d..e27b793b 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -247,6 +247,9 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): if self.config.thresholding: # Dynamic thresholding in https://arxiv.org/abs/2205.11487 + orig_dtype = x0_pred.dtype + if orig_dtype not in [torch.float, torch.double]: + x0_pred = x0_pred.float() dynamic_max_val = torch.quantile( torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 ) @@ -255,6 +258,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), )[(...,) + (None,) * (x0_pred.ndim - 1)] x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val + x0_pred = x0_pred.type(orig_dtype) return x0_pred # DPM-Solver needs to solve an integral of the noise prediction model. elif self.config.algorithm_type == "dpmsolver": diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index f90246b3..f840f8ce 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -991,6 +991,22 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest): assert abs(result_mean.item() - 0.3301) < 1e-3 + def test_fp16_support(self): + scheduler_class = self.scheduler_classes[0] + scheduler_config = self.get_scheduler_config(thresholding=True, dynamic_thresholding_ratio=0) + scheduler = scheduler_class(**scheduler_config) + + num_inference_steps = 10 + model = self.dummy_model() + sample = self.dummy_sample_deter.half() + scheduler.set_timesteps(num_inference_steps) + + for i, t in enumerate(scheduler.timesteps): + residual = model(sample, t) + sample = scheduler.step(residual, t, sample).prev_sample + + assert sample.dtype == torch.float16 + class PNDMSchedulerTest(SchedulerCommonTest): scheduler_classes = (PNDMScheduler,)