From 0f1c24664c5dabbb0e2fb11115a4afe54dbc7c34 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Thu, 1 Dec 2022 22:39:57 +0100 Subject: [PATCH] fix heun scheduler (#1512) --- src/diffusers/schedulers/scheduling_heun.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_heun.py b/src/diffusers/schedulers/scheduling_heun.py index 27a54f64..b7bf7ca3 100644 --- a/src/diffusers/schedulers/scheduling_heun.py +++ b/src/diffusers/schedulers/scheduling_heun.py @@ -186,10 +186,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise if self.config.prediction_type == "epsilon": - pred_original_sample = sample - sigma_hat * model_output + sigma_input = sigma_hat if self.state_in_first_order else sigma_next + pred_original_sample = sample - sigma_input * model_output elif self.config.prediction_type == "v_prediction": - # * c_out + input * c_skip - pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + sigma_input = sigma_hat if self.state_in_first_order else sigma_next + pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + ( + sample / (sigma_input**2 + 1) + ) else: raise ValueError( f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" @@ -207,7 +210,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin): self.sample = sample else: # 2. 2nd order / Heun's method - derivative = (sample - pred_original_sample) / sigma_hat + derivative = (sample - pred_original_sample) / sigma_next derivative = (self.prev_derivative + derivative) / 2 # 3. Retrieve 1st order derivative