[DDIM, DDPM] fix add_noise (#648)

fix add noise
This commit is contained in:
Suraj Patil 2022-09-27 17:32:43 +02:00 committed by GitHub
parent e5eed5235b
commit 3304538229
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 12 additions and 5 deletions

View File

@ -282,7 +282,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):

View File

@ -268,7 +268,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
noise: torch.FloatTensor,
timesteps: torch.IntTensor,
) -> torch.FloatTensor:
timesteps = timesteps.to(self.alphas_cumprod.device)
if self.alphas_cumprod.device != original_samples.device:
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
@ -276,7 +280,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
sqrt_one_minus_alpha_prod.flatten()
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

View File

@ -387,8 +387,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
if timesteps.device != original_samples.device:
timesteps = timesteps.to(original_samples.device)
timesteps = timesteps.to(self.alphas_cumprod.device)
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):