fix heun scheduler (#1512)
This commit is contained in:
parent
e65b71aba4
commit
0f1c24664c
|
@ -186,10 +186,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||||
if self.config.prediction_type == "epsilon":
|
if self.config.prediction_type == "epsilon":
|
||||||
pred_original_sample = sample - sigma_hat * model_output
|
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
|
||||||
|
pred_original_sample = sample - sigma_input * model_output
|
||||||
elif self.config.prediction_type == "v_prediction":
|
elif self.config.prediction_type == "v_prediction":
|
||||||
# * c_out + input * c_skip
|
sigma_input = sigma_hat if self.state_in_first_order else sigma_next
|
||||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
|
||||||
|
sample / (sigma_input**2 + 1)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||||
|
@ -207,7 +210,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
self.sample = sample
|
self.sample = sample
|
||||||
else:
|
else:
|
||||||
# 2. 2nd order / Heun's method
|
# 2. 2nd order / Heun's method
|
||||||
derivative = (sample - pred_original_sample) / sigma_hat
|
derivative = (sample - pred_original_sample) / sigma_next
|
||||||
derivative = (self.prev_derivative + derivative) / 2
|
derivative = (self.prev_derivative + derivative) / 2
|
||||||
|
|
||||||
# 3. Retrieve 1st order derivative
|
# 3. Retrieve 1st order derivative
|
||||||
|
|
Loading…
Reference in New Issue