[Type Hint] DDPM schedulers (#349)
This commit is contained in:
parent
dea5ec508f
commit
878af0e113
|
@ -15,7 +15,7 @@
|
||||||
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -51,14 +51,14 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_train_timesteps=1000,
|
num_train_timesteps: int = 1000,
|
||||||
beta_start=0.0001,
|
beta_start: float = 0.0001,
|
||||||
beta_end=0.02,
|
beta_end: float = 0.02,
|
||||||
beta_schedule="linear",
|
beta_schedule: str = "linear",
|
||||||
trained_betas=None,
|
trained_betas: Optional[np.ndarray] = None,
|
||||||
variance_type="fixed_small",
|
variance_type: str = "fixed_small",
|
||||||
clip_sample=True,
|
clip_sample: bool = True,
|
||||||
tensor_format="pt",
|
tensor_format: str = "pt",
|
||||||
):
|
):
|
||||||
|
|
||||||
if trained_betas is not None:
|
if trained_betas is not None:
|
||||||
|
@ -87,7 +87,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
self.variance_type = variance_type
|
self.variance_type = variance_type
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps):
|
def set_timesteps(self, num_inference_steps: int):
|
||||||
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
|
num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
self.timesteps = np.arange(
|
self.timesteps = np.arange(
|
||||||
|
@ -179,7 +179,13 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
return {"prev_sample": pred_prev_sample}
|
return {"prev_sample": pred_prev_sample}
|
||||||
|
|
||||||
def add_noise(self, original_samples, noise, timesteps):
|
def add_noise(
|
||||||
|
self,
|
||||||
|
original_samples: Union[torch.FloatTensor, np.ndarray],
|
||||||
|
noise: Union[torch.FloatTensor, np.ndarray],
|
||||||
|
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||||
|
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||||
|
|
||||||
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
|
||||||
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
sqrt_alpha_prod = self.match_shape(sqrt_alpha_prod, original_samples)
|
||||||
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
|
||||||
|
|
Loading…
Reference in New Issue