add v prediction (#1386)
* add v prediction * adat euler for v pred * velocity -> v_prediction * simplify * fix naming * Update src/diffusers/schedulers/scheduling_euler_discrete.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * style Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
9f476388fa
commit
30f6f44104
|
@ -122,6 +122,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
clip_sample: bool = True,
|
||||
set_alpha_to_one: bool = True,
|
||||
steps_offset: int = 0,
|
||||
prediction_type: str = "epsilon",
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
|
@ -138,6 +139,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.prediction_type = prediction_type
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
|
@ -258,7 +261,19 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
if self.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
elif self.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
elif self.prediction_type == "v_prediction":
|
||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||
# predict V
|
||||
model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction`"
|
||||
)
|
||||
|
||||
# 4. Clip "predicted x_0"
|
||||
if self.config.clip_sample:
|
||||
|
|
|
@ -78,6 +78,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[np.ndarray] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
):
|
||||
if trained_betas is not None:
|
||||
self.betas = torch.from_numpy(trained_betas)
|
||||
|
@ -91,6 +92,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
else:
|
||||
raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
|
||||
|
||||
self.prediction_type = prediction_type
|
||||
|
||||
self.alphas = 1.0 - self.betas
|
||||
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
||||
|
||||
|
@ -229,7 +232,15 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
if self.prediction_type == "epsilon":
|
||||
pred_original_sample = sample - sigma_hat * model_output
|
||||
elif self.prediction_type == "v_prediction":
|
||||
# * c_out + input * c_skip
|
||||
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
|
||||
)
|
||||
|
||||
# 2. Convert to an ODE derivative
|
||||
derivative = (sample - pred_original_sample) / sigma_hat
|
||||
|
|
Loading…
Reference in New Issue