[Bug] scheduling_ddpm: fix variance in the case of learned_range type. (#2090)
scheduling_ddpm: fix variance in the case of learned_range type. In the case of learned_range variance type, there are missing logs and exponent comparing to the theory (see "Improved Denoising Diffusion Probabilistic Models" section 3.1 equation 15: https://arxiv.org/pdf/2102.09672.pdf).
This commit is contained in:
parent
2bbd532990
commit
ecadcdefe1
|
@ -218,8 +218,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
elif variance_type == "learned":
|
elif variance_type == "learned":
|
||||||
return predicted_variance
|
return predicted_variance
|
||||||
elif variance_type == "learned_range":
|
elif variance_type == "learned_range":
|
||||||
min_log = variance
|
min_log = torch.log(variance)
|
||||||
max_log = self.betas[t]
|
max_log = torch.log(self.betas[t])
|
||||||
frac = (predicted_variance + 1) / 2
|
frac = (predicted_variance + 1) / 2
|
||||||
variance = frac * max_log + (1 - frac) * min_log
|
variance = frac * max_log + (1 - frac) * min_log
|
||||||
|
|
||||||
|
@ -304,6 +304,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
)
|
)
|
||||||
if self.variance_type == "fixed_small_log":
|
if self.variance_type == "fixed_small_log":
|
||||||
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
|
variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise
|
||||||
|
elif self.variance_type == "learned_range":
|
||||||
|
variance = self._get_variance(t, predicted_variance=predicted_variance)
|
||||||
|
variance = torch.exp(0.5 * variance) * variance_noise
|
||||||
else:
|
else:
|
||||||
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
|
variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue