parent
e5eed5235b
commit
3304538229
|
@ -282,7 +282,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
if self.alphas_cumprod.device != original_samples.device:
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
|
||||
|
||||
if timesteps.device != original_samples.device:
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
|
|
|
@ -268,7 +268,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
) -> torch.FloatTensor:
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
if self.alphas_cumprod.device != original_samples.device:
|
||||
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
|
||||
|
||||
if timesteps.device != original_samples.device:
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
|
@ -276,7 +280,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod.flatten()
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
|
|
|
@ -387,8 +387,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
if timesteps.device != original_samples.device:
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
|
|
Loading…
Reference in New Issue