[Type hint] scheduling karras ve (#359)

This commit is contained in:
Santiago Víquez 2022-09-05 18:20:57 +02:00 committed by GitHub
parent 07f8ebd543
commit 3c1cdd3359
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 12 additions and 10 deletions

View File

@ -14,7 +14,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Tuple, Union from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -54,13 +54,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
sigma_min=0.02, sigma_min: float = 0.02,
sigma_max=100, sigma_max: float = 100,
s_noise=1.007, s_noise: float = 1.007,
s_churn=80, s_churn: float = 80,
s_min=0.05, s_min: float = 0.05,
s_max=50, s_max: float = 50,
tensor_format="pt", tensor_format: str = "pt",
): ):
""" """
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
@ -87,7 +87,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format) self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps): def set_timesteps(self, num_inference_steps: int):
self.num_inference_steps = num_inference_steps self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy() self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.schedule = [ self.schedule = [
@ -98,7 +98,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
self.set_format(tensor_format=self.tensor_format) self.set_format(tensor_format=self.tensor_format)
def add_noise_to_input(self, sample, sigma, generator=None): def add_noise_to_input(
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
""" """
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i 0 to reach a Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i 0 to reach a
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i. higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.