Add typing to scheduling_sde_ve: init, set_timesteps, and set_sigmas function definitions (#412)

Add typing to scheduling_sde_ve init, set_timesteps, and set_sigmas functions

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Daniel Hug 2022-09-08 03:17:14 -04:00 committed by GitHub
parent b8894f181d
commit a127363dca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 9 deletions

View File

@ -65,13 +65,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=2000,
snr=0.15,
sigma_min=0.01,
sigma_max=1348,
sampling_eps=1e-5,
correct_steps=1,
tensor_format="pt",
num_train_timesteps: int = 2000,
snr: float = 0.15,
sigma_min: float = 0.01,
sigma_max: float = 1348.0,
sampling_eps: float = 1e-5,
correct_steps: int = 1,
tensor_format: str = "pt",
):
# setable values
self.timesteps = None
@ -81,7 +81,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps, sampling_eps=None):
def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
"""
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@ -100,7 +100,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
else:
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
def set_sigmas(
self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
):
"""
Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.