Add typing to scheduling_sde_ve: init, set_timesteps, and set_sigmas function definitions (#412)
Add typing to scheduling_sde_ve init, set_timesteps, and set_sigmas functions Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
b8894f181d
commit
a127363dca
|
@ -65,13 +65,13 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps=2000,
|
||||
snr=0.15,
|
||||
sigma_min=0.01,
|
||||
sigma_max=1348,
|
||||
sampling_eps=1e-5,
|
||||
correct_steps=1,
|
||||
tensor_format="pt",
|
||||
num_train_timesteps: int = 2000,
|
||||
snr: float = 0.15,
|
||||
sigma_min: float = 0.01,
|
||||
sigma_max: float = 1348.0,
|
||||
sampling_eps: float = 1e-5,
|
||||
correct_steps: int = 1,
|
||||
tensor_format: str = "pt",
|
||||
):
|
||||
# setable values
|
||||
self.timesteps = None
|
||||
|
@ -81,7 +81,7 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def set_timesteps(self, num_inference_steps, sampling_eps=None):
|
||||
def set_timesteps(self, num_inference_steps: int, sampling_eps: float = None):
|
||||
"""
|
||||
Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
@ -100,7 +100,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
|
|||
else:
|
||||
raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
|
||||
|
||||
def set_sigmas(self, num_inference_steps, sigma_min=None, sigma_max=None, sampling_eps=None):
|
||||
def set_sigmas(
|
||||
self, num_inference_steps: int, sigma_min: float = None, sigma_max: float = None, sampling_eps: float = None
|
||||
):
|
||||
"""
|
||||
Sets the noise scales used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
|
|
Loading…
Reference in New Issue