diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 77003f56..36717a8c 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -250,6 +250,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 2b02bfa6..43a73b03 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -251,6 +251,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 750ec931..19272d62 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -120,7 +120,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): frac = np.mod(self.timesteps, 1.0) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx] - self.sigmas = np.concatenate([sigmas, [0.0]]) + self.sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) self.derivatives = [] @@ -183,6 +183,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> Union[torch.FloatTensor, np.ndarray]: + if self.tensor_format == "pt": + timesteps = timesteps.to(self.sigmas.device) sigmas = self.match_shape(self.sigmas[timesteps], noise) noisy_samples = original_samples + noise * sigmas diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index bbcc1e1f..bfe31aca 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -367,8 +367,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): noise: Union[torch.FloatTensor, np.ndarray], timesteps: Union[torch.IntTensor, np.ndarray], ) -> torch.Tensor: - # mps requires indices to be in the same device, so we use cpu as is the default with cuda - timesteps = timesteps.to(self.alphas_cumprod.device) + if self.tensor_format == "pt": + timesteps = timesteps.to(self.alphas_cumprod.device) sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples) sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5