fix alphas_cumprod
This commit is contained in:
parent
4497e78d00
commit
e1ef122260
|
@ -137,8 +137,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
return pred_prev_sample
|
||||
|
||||
def forward_step(self, original_sample, noise, t):
|
||||
sqrt_alpha_prod = self.alpha_prod_t[t] ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alpha_prod_t[t]) ** 0.5
|
||||
sqrt_alpha_prod = self.alphas_cumprod[t] ** 0.5
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[t]) ** 0.5
|
||||
noisy_sample = sqrt_alpha_prod * original_sample + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_sample
|
||||
|
||||
|
|
Loading…
Reference in New Issue