[Type hint] scheduling karras ve (#359)

This commit is contained in:
Santiago Víquez 2022-09-05 18:20:57 +02:00 committed by GitHub
parent 07f8ebd543
commit 3c1cdd3359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 10 deletions

View File

@ -14,7 +14,7 @@
from dataclasses import dataclass
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@ -54,13 +54,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
sigma_min=0.02,
sigma_max=100,
s_noise=1.007,
s_churn=80,
s_min=0.05,
s_max=50,
tensor_format="pt",
sigma_min: float = 0.02,
sigma_max: float = 100,
s_noise: float = 1.007,
s_churn: float = 80,
s_min: float = 0.05,
s_max: float = 50,
tensor_format: str = "pt",
):
"""
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
@ -87,7 +87,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps: int):
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [
@ -98,7 +98,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=self.tensor_format)
def add_noise_to_input(self, sample, sigma, generator=None):
def add_noise_to_input(
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.