parent
e5eed5235b
commit
3304538229
|
@ -282,7 +282,12 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
noise: torch.FloatTensor,
|
noise: torch.FloatTensor,
|
||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
if self.alphas_cumprod.device != original_samples.device:
|
||||||
|
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
|
||||||
|
|
||||||
|
if timesteps.device != original_samples.device:
|
||||||
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
|
||||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||||
|
|
|
@ -268,7 +268,11 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
noise: torch.FloatTensor,
|
noise: torch.FloatTensor,
|
||||||
timesteps: torch.IntTensor,
|
timesteps: torch.IntTensor,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
if self.alphas_cumprod.device != original_samples.device:
|
||||||
|
self.alphas_cumprod = self.alphas_cumprod.to(original_samples.device)
|
||||||
|
|
||||||
|
if timesteps.device != original_samples.device:
|
||||||
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
|
||||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||||
|
@ -276,7 +280,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||||
|
|
||||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||||
sqrt_one_minus_alpha_prod.flatten()
|
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||||
|
|
||||||
|
|
|
@ -387,8 +387,6 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
if timesteps.device != original_samples.device:
|
if timesteps.device != original_samples.device:
|
||||||
timesteps = timesteps.to(original_samples.device)
|
timesteps = timesteps.to(original_samples.device)
|
||||||
|
|
||||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
|
||||||
|
|
||||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||||
|
|
Loading…
Reference in New Issue