fix heun scheduler (#1512)

This commit is contained in:
Suraj Patil 2022-12-01 22:39:57 +01:00 committed by GitHub
parent e65b71aba4
commit 0f1c24664c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 7 additions and 4 deletions

View File

@ -186,10 +186,13 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon": if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output sigma_input = sigma_hat if self.state_in_first_order else sigma_next
pred_original_sample = sample - sigma_input * model_output
elif self.config.prediction_type == "v_prediction": elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip sigma_input = sigma_hat if self.state_in_first_order else sigma_next
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) pred_original_sample = model_output * (-sigma_input / (sigma_input**2 + 1) ** 0.5) + (
sample / (sigma_input**2 + 1)
)
else: else:
raise ValueError( raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`" f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
@ -207,7 +210,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.sample = sample self.sample = sample
else: else:
# 2. 2nd order / Heun's method # 2. 2nd order / Heun's method
derivative = (sample - pred_original_sample) / sigma_hat derivative = (sample - pred_original_sample) / sigma_next
derivative = (self.prev_derivative + derivative) / 2 derivative = (self.prev_derivative + derivative) / 2
# 3. Retrieve 1st order derivative # 3. Retrieve 1st order derivative