From 249b36cc38ba3554102c32b3a78705324d23bf62 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 3 Oct 2022 15:07:09 +0200 Subject: [PATCH] Flax: add shape argument to `set_timesteps` (#690) * Flax: add shape argument to set_timesteps * style --- src/diffusers/schedulers/scheduling_ddim_flax.py | 2 +- src/diffusers/schedulers/scheduling_ddpm_flax.py | 2 +- src/diffusers/schedulers/scheduling_karras_ve_flax.py | 4 +++- src/diffusers/schedulers/scheduling_lms_discrete_flax.py | 4 +++- src/diffusers/schedulers/scheduling_pndm_flax.py | 2 +- src/diffusers/schedulers/scheduling_sde_ve_flax.py | 2 +- 6 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index d81d6660..a57f61e0 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -156,7 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin): return variance - def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState: + def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 7c7b8d29..8f631403 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -133,7 +133,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin): self.variance_type = variance_type - def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState: + def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index c320b79e..72ff69da 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -99,7 +99,9 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin): ): self.state = KarrasVeSchedulerState.create() - def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState: + def set_timesteps( + self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple + ) -> KarrasVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 4784e4fa..51b3dfc7 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -111,7 +111,9 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin): return integrated_coeff - def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState: + def set_timesteps( + self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple + ) -> LMSDiscreteSchedulerState: """ Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 9e2b19f0..871a406f 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -156,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin): def create_state(self): return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps) - def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState: + def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index 08fbe147..3c8834d7 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -95,7 +95,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps) def set_timesteps( - self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None + self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None ) -> ScoreSdeVeSchedulerState: """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.