From 3dacbb94ca25b7db876778aefbc41f9984b919e5 Mon Sep 17 00:00:00 2001 From: V Vishnu Anirudh Date: Thu, 29 Sep 2022 18:21:04 +0100 Subject: [PATCH] `trained_betas` ignored in some schedulers (#635) * correcting the beta value assignment * updating DDIM and LMSDiscreteFlax schedulers * bringing back the changes that were lost as part of main branch merge --- src/diffusers/schedulers/scheduling_ddim.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete.py | 2 +- src/diffusers/schedulers/scheduling_lms_discrete_flax.py | 2 +- src/diffusers/schedulers/scheduling_pndm.py | 2 +- src/diffusers/schedulers/scheduling_pndm_flax.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 0d9e285e..9079ba90 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -131,7 +131,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) - if beta_schedule == "linear": + elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 4595b2fe..ec445e72 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -86,7 +86,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) - if beta_schedule == "linear": + elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 303ef4e4..4784e4fa 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -74,7 +74,7 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) - if beta_schedule == "linear": + elif beta_schedule == "linear": self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index d9e430c4..3015d153 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -111,7 +111,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): if trained_betas is not None: self.betas = torch.from_numpy(trained_betas) - if beta_schedule == "linear": + elif beta_schedule == "linear": self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 4b417221..9e2b19f0 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -132,7 +132,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): ): if trained_betas is not None: self.betas = jnp.asarray(trained_betas) - if beta_schedule == "linear": + elif beta_schedule == "linear": self.betas = jnp.linspace(beta_start, beta_end, num_train_timesteps, dtype=jnp.float32) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model.