[Type hint] PNDM schedulers (#335)

* [Type hint] PNDM Schedulers

* ran make style

* updated timesteps type hint

* apply suggestions from code review

* ran make style

* removed unused import
This commit is contained in:
Partho 2022-09-04 21:31:57 +05:30 committed by GitHub
parent 6c0ca5efa6
commit dea5ec508f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 8 deletions

View File

@ -51,12 +51,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
tensor_format="pt",
skip_prk_steps=False,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
tensor_format: str = "pt",
skip_prk_steps: bool = False,
):
if beta_schedule == "linear":
@ -97,7 +97,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
self.tensor_format = tensor_format
self.set_format(tensor_format=tensor_format)
def set_timesteps(self, num_inference_steps, offset=0):
def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
self.num_inference_steps = num_inference_steps
self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
@ -264,7 +264,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
return prev_sample
def add_noise(self, original_samples, noise, timesteps):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> torch.Tensor:
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5