From a127363dcabdc4c0625ef24be0e0d8143de18af2 Mon Sep 17 00:00:00 2001 From: Daniel Hug <38571110+danielpatrickhug@users.noreply.github.com> Date: Thu, 8 Sep 2022 03:17:14 -0400 Subject: [PATCH] Add typing to scheduling_sde_ve: init, set_timesteps, and set_sigmas function definitions (#412) Add typing to scheduling_sde_ve init, set_timesteps, and set_sigmas functions Co-authored-by: Patrick von Platen --- src/diffusers/schedulers/scheduling_sde_ve.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index 7e203db6..308f42c9 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -65,13 +65,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): @register_to_config def __init__( self, - num_train_timesteps=2000, - snr=0.15, - sigma_min=0.01, - sigma_max=1348, - sampling_eps=1e-5, - correct_steps=1, - tensor_format="pt", + num_train_timesteps: int = 2000, + snr: float = 0.15, + sigma_min: float = 0.01, + sigma_max: float = 1348.0, + sampling_eps: float = 1e-5, + correct_steps: int = 1, + tensor_format: str = "pt", ): # setable values self.timesteps = None @@ -81,7 +81,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): self.tensor_format = tensor_format self.set_format(tensor_format=tensor_format) - def set_timesteps(self, num_inference_steps, sampling_eps=None): + def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None): """ Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -100,7 +100,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): else: raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") - def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None): + def set_sigmas( + self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None + ): """ Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.