Flax: add shape argument to `set_timesteps` (#690)
* Flax: add shape argument to set_timesteps * style
This commit is contained in:
parent
500ca5a907
commit
249b36cc38
|
@ -156,7 +156,7 @@ class FlaxDDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return variance
|
||||
|
||||
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int) -> DDIMSchedulerState:
|
||||
def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
|
|
@ -133,7 +133,7 @@ class FlaxDDPMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
self.variance_type = variance_type
|
||||
|
||||
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int) -> DDPMSchedulerState:
|
||||
def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
|
|
@ -99,7 +99,9 @@ class FlaxKarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
):
|
||||
self.state = KarrasVeSchedulerState.create()
|
||||
|
||||
def set_timesteps(self, state: KarrasVeSchedulerState, num_inference_steps: int) -> KarrasVeSchedulerState:
|
||||
def set_timesteps(
|
||||
self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
|
||||
) -> KarrasVeSchedulerState:
|
||||
"""
|
||||
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
|
|
@ -111,7 +111,9 @@ class FlaxLMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return integrated_coeff
|
||||
|
||||
def set_timesteps(self, state: LMSDiscreteSchedulerState, num_inference_steps: int) -> LMSDiscreteSchedulerState:
|
||||
def set_timesteps(
|
||||
self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
|
||||
) -> LMSDiscreteSchedulerState:
|
||||
"""
|
||||
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
|
|
@ -156,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
def create_state(self):
|
||||
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
|
||||
|
||||
def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
|
||||
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, shape: Tuple) -> PNDMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ class FlaxScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
|
||||
|
||||
def set_timesteps(
|
||||
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, sampling_eps: float = None
|
||||
self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
|
||||
) -> ScoreSdeVeSchedulerState:
|
||||
"""
|
||||
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
|
Loading…
Reference in New Issue