Add typing to scheduling_sde_ve: init, set_timesteps, and set_sigmas function definitions (#412)

Add typing to scheduling_sde_ve init, set_timesteps, and set_sigmas functions

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Daniel Hug 2022-09-08 03:17:14 -04:00 committed by GitHub
parent b8894f181d
commit a127363dca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 9 deletions

View File

@ -65,13 +65,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
num_train_timesteps=2000, num_train_timesteps: int = 2000,
snr=0.15, snr: float = 0.15,
sigma_min=0.01, sigma_min: float = 0.01,
sigma_max=1348, sigma_max: float = 1348.0,
sampling_eps=1e-5, sampling_eps: float = 1e-5,
correct_steps=1, correct_steps: int = 1,
tensor_format="pt", tensor_format: str = "pt",
): ):
# setable values # setable values
self.timesteps = None self.timesteps = None
@ -81,7 +81,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps, sampling_eps=None): def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
""" """
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference. Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@ -100,7 +100,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
else: else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.") raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None): def set_sigmas(
self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
):
""" """
Sets the noise scales used for the diffusion chain. Supporting function to be run before inference. Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.