[Type hint] scheduling karras ve (#359)
This commit is contained in:
parent
07f8ebd543
commit
3c1cdd3359
|
@ -14,7 +14,7 @@
|
||||||
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -54,13 +54,13 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
sigma_min=0.02,
|
sigma_min: float = 0.02,
|
||||||
sigma_max=100,
|
sigma_max: float = 100,
|
||||||
s_noise=1.007,
|
s_noise: float = 1.007,
|
||||||
s_churn=80,
|
s_churn: float = 80,
|
||||||
s_min=0.05,
|
s_min: float = 0.05,
|
||||||
s_max=50,
|
s_max: float = 50,
|
||||||
tensor_format="pt",
|
tensor_format: str = "pt",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
|
For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of
|
||||||
|
@ -87,7 +87,7 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
self.tensor_format = tensor_format
|
self.tensor_format = tensor_format
|
||||||
self.set_format(tensor_format=tensor_format)
|
self.set_format(tensor_format=tensor_format)
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps):
|
def set_timesteps(self, num_inference_steps: int):
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
|
self.timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
|
||||||
self.schedule = [
|
self.schedule = [
|
||||||
|
@ -98,7 +98,9 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
self.set_format(tensor_format=self.tensor_format)
|
self.set_format(tensor_format=self.tensor_format)
|
||||||
|
|
||||||
def add_noise_to_input(self, sample, sigma, generator=None):
|
def add_noise_to_input(
|
||||||
|
self, sample: Union[torch.FloatTensor, np.ndarray], sigma: float, generator: Optional[torch.Generator] = None
|
||||||
|
) -> Tuple[Union[torch.FloatTensor, np.ndarray], float]:
|
||||||
"""
|
"""
|
||||||
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
|
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
|
||||||
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
|
higher noise level sigma_hat = sigma_i + gamma_i*sigma_i.
|
||||||
|
|
Loading…
Reference in New Issue