[Type hint] PNDM schedulers (#335)
* [Type hint] PNDM Schedulers * ran make style * updated timesteps type hint * apply suggestions from code review * ran make style * removed unused import
This commit is contained in:
parent
6c0ca5efa6
commit
dea5ec508f
|
@ -51,12 +51,12 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps=1000,
|
||||
beta_start=0.0001,
|
||||
beta_end=0.02,
|
||||
beta_schedule="linear",
|
||||
tensor_format="pt",
|
||||
skip_prk_steps=False,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
tensor_format: str = "pt",
|
||||
skip_prk_steps: bool = False,
|
||||
):
|
||||
|
||||
if beta_schedule == "linear":
|
||||
|
@ -97,7 +97,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
self.tensor_format = tensor_format
|
||||
self.set_format(tensor_format=tensor_format)
|
||||
|
||||
def set_timesteps(self, num_inference_steps, offset=0):
|
||||
def set_timesteps(self, num_inference_steps: int, offset: int = 0) -> torch.FloatTensor:
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self._timesteps = list(
|
||||
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
|
||||
|
@ -264,7 +264,13 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
return prev_sample
|
||||
|
||||
def add_noise(self, original_samples, noise, timesteps):
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: Union[torch.FloatTensor, np.ndarray],
|
||||
noise: Union[torch.FloatTensor, np.ndarray],
|
||||
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||
) -> torch.Tensor:
|
||||
|
||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||
|
|
Loading…
Reference in New Issue