add v prediction (#1386)

* add v prediction

* adat euler for v pred

* velocity -> v_prediction

* simplify

* fix naming

* Update src/diffusers/schedulers/scheduling_euler_discrete.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* style

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Suraj Patil 2022-11-24 12:25:19 +01:00 committed by GitHub
parent 9f476388fa
commit 30f6f44104
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 2 deletions

View File

@ -122,6 +122,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
prediction_type: str = "epsilon",
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
@ -138,6 +139,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.prediction_type = prediction_type
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@ -258,7 +261,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
if self.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.prediction_type == "sample":
pred_original_sample = model_output
elif self.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# predict V
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
# 4. Clip "predicted x_0"
if self.config.clip_sample:

View File

@ -78,6 +78,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
prediction_type: str = "epsilon",
):
if trained_betas is not None:
self.betas = torch.from_numpy(trained_betas)
@ -91,6 +92,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
else:
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
self.prediction_type = prediction_type
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
@ -229,7 +232,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
pred_original_sample = sample - sigma_hat * model_output
if self.prediction_type == "epsilon":
pred_original_sample = sample - sigma_hat * model_output
elif self.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat