fix heun scheduler (#1512)
This commit is contained in:
parent
e65b71aba4
commit
0f1c24664c
|
@ -186,10 +186,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
|
||||
pred_original_sample = sample - sigma_input * model_output
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
|
||||
pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
|
||||
sample / (sigma_input**2 + 1)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
|
@ -207,7 +210,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
self.sample = sample
|
||||
else:
|
||||
# 2. 2nd order / Heun's method
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
derivative = (sample - pred_original_sample) / sigma_next
|
||||
derivative = (self.prev_derivative + derivative) / 2
|
||||
|
||||
# 3. Retrieve 1st order derivative
|
||||
|
|
Loading…
Reference in New Issue