Fix MPS scheduler indexing when using `mps` (#450)
* Fix LMS scheduler indexing in `add_noise` #358. * Fix DDIM and DDPM indexing with mps device. * Verify format is PyTorch before using `.to()`
This commit is contained in:
parent
7c4b38baca
commit
1a69c6ff0e
|
@ -250,6 +250,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||
if self.tensor_format == "pt":
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
|
|
|
@ -251,6 +251,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||
if self.tensor_format == "pt":
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
|
|
|
@ -120,7 +120,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
frac = np.mod(self.timesteps, 1.0)
|
||||
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
|
||||
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
|
||||
self.sigmas = np.concatenate([sigmas, [0.0]])
|
||||
self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
|
||||
|
||||
self.derivatives = []
|
||||
|
||||
|
@ -183,6 +183,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||
if self.tensor_format == "pt":
|
||||
timesteps = timesteps.to(self.sigmas.device)
|
||||
sigmas = self.match_shape(self.sigmas[timesteps], noise)
|
||||
noisy_samples = original_samples + noise * sigmas
|
||||
|
||||
|
|
|
@ -367,7 +367,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> torch.Tensor:
|
||||
# mps requires indices to be in the same device, so we use cpu as is the default with cuda
|
||||
if self.tensor_format == "pt":
|
||||
timesteps = timesteps.to(self.alphas_cumprod.device)
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
|
|
Loading…
Reference in New Issue