Improve ddim scheduler and fix bug when prediction type is "sample" (#2094)
Improve ddim scheduler

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
c5f6c538fd
commit
c812d97d5b
|
@ -301,12 +301,13 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
if self.config.prediction_type == "epsilon":
|
if self.config.prediction_type == "epsilon":
|
||||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||||
|
pred_epsilon = model_output
|
||||||
elif self.config.prediction_type == "sample":
|
elif self.config.prediction_type == "sample":
|
||||||
pred_original_sample = model_output
|
pred_original_sample = model_output
|
||||||
|
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
||||||
elif self.config.prediction_type == "v_prediction":
|
elif self.config.prediction_type == "v_prediction":
|
||||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||||
# predict V
|
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||||
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||||
|
@ -328,17 +329,16 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
std_dev_t = eta * variance ** (0.5)
|
std_dev_t = eta * variance ** (0.5)
|
||||||
|
|
||||||
if use_clipped_model_output:
|
if use_clipped_model_output:
|
||||||
# the model_output is always re-derived from the clipped x_0 in Glide
|
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
|
||||||
model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
||||||
|
|
||||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
|
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
|
||||||
|
|
||||||
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||||
|
|
||||||
if eta > 0:
|
if eta > 0:
|
||||||
device = model_output.device
|
|
||||||
if variance_noise is not None and generator is not None:
|
if variance_noise is not None and generator is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
|
"Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
|
||||||
|
@ -347,7 +347,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
if variance_noise is None:
|
if variance_noise is None:
|
||||||
variance_noise = randn_tensor(
|
variance_noise = randn_tensor(
|
||||||
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
|
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||||
)
|
)
|
||||||
variance = std_dev_t * variance_noise
|
variance = std_dev_t * variance_noise
|
||||||
|
|
||||||
|
|
|
@ -254,12 +254,13 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
if self.config.prediction_type == "epsilon":
|
if self.config.prediction_type == "epsilon":
|
||||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||||
|
pred_epsilon = model_output
|
||||||
elif self.config.prediction_type == "sample":
|
elif self.config.prediction_type == "sample":
|
||||||
pred_original_sample = model_output
|
pred_original_sample = model_output
|
||||||
|
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
||||||
elif self.config.prediction_type == "v_prediction":
|
elif self.config.prediction_type == "v_prediction":
|
||||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||||
# predict V
|
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||||
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||||
|
@ -272,7 +273,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||||
std_dev_t = eta * variance ** (0.5)
|
std_dev_t = eta * variance ** (0.5)
|
||||||
|
|
||||||
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
|
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
|
||||||
|
|
||||||
# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
# 6. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||||
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||||
|
|
Loading…
Reference in New Issue