adding more typehints to DDIM scheduler (#456)

* adding more typehints

* resolving mypy issues

* resolving formatting issue

* fixing isort issue

Co-authored-by: V Vishnu Anirudh <git.vva@gmail.com>
Co-authored-by: V Vishnu Anirudh <vvani@kth.se>
V Vishnu Anirudh, 2022-09-16 16:41:58 +01:00, committed by GitHub
parent 06924c6a4f
commit a0558b1146
1 changed file with 15 additions and 15 deletions


@@ -16,7 +16,7 @@
 # and https://github.com/hojonathanho/diffusion
 
 import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -25,7 +25,7 @@ from ..configuration_utils import ConfigMixin, register_to_config
 from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
-def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
+def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> np.ndarray:
     """
     Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
     (1-beta) over time from t = [0,1].
@@ -43,14 +43,14 @@ def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
         betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
     """
 
-    def alpha_bar(time_step):
+    def calculate_alpha_bar(time_step: float) -> float:
         return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
 
-    betas = []
-    for i in range(num_diffusion_timesteps):
-        t1 = i / num_diffusion_timesteps
-        t2 = (i + 1) / num_diffusion_timesteps
-        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+    betas: List[float] = []
+    for diffusion_timestep in range(num_diffusion_timesteps):
+        lower_timestep = diffusion_timestep / num_diffusion_timesteps
+        upper_timestep = (diffusion_timestep + 1) / num_diffusion_timesteps
+        betas.append(min(1 - calculate_alpha_bar(upper_timestep) / calculate_alpha_bar(lower_timestep), max_beta))
     return np.array(betas, dtype=np.float32)
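
For reference, the typed helper introduced in the hunk above runs standalone. The sketch below copies the new code verbatim and adds one illustrative call at the end (the call itself is not part of the commit):

import math
from typing import List

import numpy as np


def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta: float = 0.999) -> np.ndarray:
    # Cosine schedule: alpha_bar(t) approximates the cumulative product of
    # (1 - beta) over t in [0, 1]; each beta is clipped at max_beta.
    def calculate_alpha_bar(time_step: float) -> float:
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas: List[float] = []
    for diffusion_timestep in range(num_diffusion_timesteps):
        lower_timestep = diffusion_timestep / num_diffusion_timesteps
        upper_timestep = (diffusion_timestep + 1) / num_diffusion_timesteps
        betas.append(min(1 - calculate_alpha_bar(upper_timestep) / calculate_alpha_bar(lower_timestep), max_beta))
    return np.array(betas, dtype=np.float32)


print(betas_for_alpha_bar(1000).shape)  # -> (1000,)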
@@ -96,7 +96,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         tensor_format: str = "pt",
     ):
         if trained_betas is not None:
-            self.betas = np.asarray(trained_betas)
+            self.betas: np.ndarray = np.asarray(trained_betas)
         if beta_schedule == "linear":
             self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
@@ -108,8 +108,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
-        self.alphas = 1.0 - self.betas
-        self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
+        self.alphas: np.ndarray = 1.0 - self.betas
+        self.alphas_cumprod: np.ndarray = np.cumprod(self.alphas, axis=0)
 
         # At every step in ddim, we are looking into the previous alphas_cumprod
         # For the final step, there is no previous alphas_cumprod because we are already at 0
@@ -118,10 +118,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         self.final_alpha_cumprod = np.array(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
 
         # setable values
-        self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
-        self.tensor_format = tensor_format
+        self.num_inference_steps: int = 0
+        self.timesteps: np.ndarray = np.arange(0, num_train_timesteps)[::-1].copy()
+        self.tensor_format: str = tensor_format
 
         self.set_format(tensor_format=tensor_format)
 
     def _get_variance(self, timestep, prev_timestep):
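
One behavioral nuance in the hunk above: `num_inference_steps` trades its `None` sentinel for `0` so the attribute can be annotated as a plain `int`. A minimal sketch of the two options (class names are hypothetical, not from the commit):

from typing import Optional


class IntSentinel:
    def __init__(self) -> None:
        # As in this commit: 0 means "not yet configured", type stays int.
        self.num_inference_steps: int = 0


class OptionalSentinel:
    def __init__(self) -> None:
        # Alternative: keep None, but callers must narrow Optional[int] first.
        self.num_inference_steps: Optional[int] = None

The int-typed sentinel keeps call sites free of Optional narrowing, at the cost of treating 0 as a magic "unset" value.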
@@ -134,7 +134,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         return variance
 
-    def set_timesteps(self, num_inference_steps: int, offset: int = 0):
+    def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> None:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
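
As a quick smoke test of the annotated scheduler (a sketch, assuming a diffusers build containing this commit, in which `set_timesteps` still accepts an `offset` argument):

from diffusers import DDIMScheduler

scheduler = DDIMScheduler(beta_schedule="linear")
scheduler.set_timesteps(num_inference_steps=50)
print(scheduler.timesteps[:5])  # first five timesteps, in reverse order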