adding more typehints to DDIM scheduler (#456)
* adding more typehints
* resolving mypy issues
* resolving formatting issue
* fixing isort issue

Co-authored-by: V Vishnu Anirudh <git.vva@gmail.com>
Co-authored-by: V Vishnu Anirudh <vvani@kth.se>
This commit is contained in:
parent 06924c6a4f
commit a0558b1146
@@ -16,7 +16,7 @@
 # and https://github.com/hojonathanho/diffusion
 
 import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -25,7 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
-def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> np.ndarray:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
@@ -43,14 +43,14 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
         betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
     """
 
-    def alpha_bar(time_step):
+    def calculate_alpha_bar(time_step: float) -> float:
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
 
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    betas: List[float] = []
+    for diffusion_timestep in range(num_diffusion_timesteps):
+        lower_timestep = diffusion_timestep / num_diffusion_timesteps
+        upper_timestep = (diffusion_timestep + 1) / num_diffusion_timesteps
+        betas.append(min(1 - calculate_alpha_bar(upper_timestep) / calculate_alpha_bar(lower_timestep), max_beta))
     return np.array(betas, dtype=np.float32)
 
 
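This hunk annotates the cosine schedule helper end to end: each beta_t is derived from the ratio of the cumulative alpha_bar at consecutive timestep fractions, clipped at max_beta. A minimal standalone sketch of the typed function with a quick sanity check (the names here are local copies for illustration, not the diffusers exports):

    import math
    from typing import List

    import numpy as np


    def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> np.ndarray:
        def calculate_alpha_bar(time_step: float) -> float:
            # cosine-squared cumulative schedule; the 0.008 shift avoids a singularity at t = 0
            return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

        betas: List[float] = []
        for diffusion_timestep in range(num_diffusion_timesteps):
            lower_timestep = diffusion_timestep / num_diffusion_timesteps
            upper_timestep = (diffusion_timestep + 1) / num_diffusion_timesteps
            # beta_t = 1 - alpha_bar(t+1) / alpha_bar(t), clipped so no single step destroys too much signal
            betas.append(min(1 - calculate_alpha_bar(upper_timestep) / calculate_alpha_bar(lower_timestep), max_beta))
        return np.array(betas, dtype=np.float32)


    betas = betas_for_alpha_bar(1000)
    assert betas.shape == (1000,) and betas.dtype == np.float32
    assert betas.max() <= np.float32(0.999)  # the final step hits the max_beta clip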
@@ -96,7 +96,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         tensor_format: str = "pt",
     ):
         if trained_betas is not None:
-            self.betas = np.asarray(trained_betas)
+            self.betas: np.ndarray = np.asarray(trained_betas)
         if beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
@@ -108,8 +108,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.alphas: np.ndarray = 1.0 - self.betas
+        self.alphas_cumprod: np.ndarray = np.cumprod(self.alphas, axis=0)
 
         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
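The two newly annotated attributes are the standard DDPM bookkeeping: alphas = 1 - betas, and alphas_cumprod[t] is the running product of alphas up to step t (the quantity alpha_bar that the helper above discretizes). A quick sketch of that relationship, independent of the class:

    import numpy as np

    betas = np.linspace(0.0001, 0.02, 10, dtype=np.float32)
    alphas: np.ndarray = 1.0 - betas
    alphas_cumprod: np.ndarray = np.cumprod(alphas, axis=0)

    # alphas_cumprod[t] == alphas[0] * alphas[1] * ... * alphas[t]
    assert np.allclose(alphas_cumprod[3], alphas[:4].prod())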
@@ -118,10 +118,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
 
         # setable values
-        self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.num_inference_steps: int = 0
+        self.timesteps: np.ndarray = np.arange(0, num_train_timesteps)[::-1].copy()
 
-        self.tensor_format = tensor_format
+        self.tensor_format: str = tensor_format
         self.set_format(tensor_format=tensor_format)
 
     def _get_variance(self, timestep, prev_timestep):
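Note the sentinel change that the int annotation forces: num_inference_steps was previously None until set_timesteps ran, but mypy rejects assigning None to a plain int, so the diff uses 0 as the "not yet configured" value. The alternative, sketched here purely to illustrate the trade-off (this class and its names are hypothetical, not part of the diff), is Optional[int], which keeps the None sentinel at the cost of a None-check at every use:

    from typing import Optional


    class SchedulerStateSketch:
        def __init__(self) -> None:
            # Optional[int] keeps "unset" explicit, but every consumer must narrow the type
            self.num_inference_steps: Optional[int] = None

        def set_timesteps(self, num_inference_steps: int) -> None:
            self.num_inference_steps = num_inference_steps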
@@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
-    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+    def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
 
 
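With the explicit -> None return annotation, set_timesteps reads as what it is: a mutating call that configures self.num_inference_steps and self.timesteps rather than returning a schedule. A usage sketch, assuming the diffusers version this commit targets (where the constructor still accepted tensor_format):

    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(tensor_format="np")
    scheduler.set_timesteps(num_inference_steps=50)  # returns None; mutates scheduler state
    print(scheduler.num_inference_steps)             # 50
    print(scheduler.timesteps[:3])                   # first few timesteps, in reverse order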