[Type Hints] VAE models (#344)
* [Type Hints] VAE models * apply suggestions from code review apply suggestions to also return the return type
This commit is contained in:
parent
878af0e113
commit
5791f4acde
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -293,7 +295,7 @@ class DiagonalGaussianDistribution(object):
|
||||||
if self.deterministic:
|
if self.deterministic:
|
||||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||||
|
|
||||||
def sample(self, generator=None):
|
def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
|
||||||
x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
|
x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -327,16 +329,16 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels=3,
|
in_channels: int = 3,
|
||||||
out_channels=3,
|
out_channels: int = 3,
|
||||||
down_block_types=("DownEncoderBlock2D",),
|
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||||
up_block_types=("UpDecoderBlock2D",),
|
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||||
block_out_channels=(64,),
|
block_out_channels: Tuple[int] = (64,),
|
||||||
layers_per_block=1,
|
layers_per_block: int = 1,
|
||||||
act_fn="silu",
|
act_fn: str = "silu",
|
||||||
latent_channels=3,
|
latent_channels: int = 3,
|
||||||
sample_size=32,
|
sample_size: int = 32,
|
||||||
num_vq_embeddings=256,
|
num_vq_embeddings: int = 256,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -382,7 +384,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||||
dec = self.decoder(quant)
|
dec = self.decoder(quant)
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
def forward(self, sample):
|
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
x = sample
|
x = sample
|
||||||
h = self.encode(x)
|
h = self.encode(x)
|
||||||
dec = self.decode(h)
|
dec = self.decode(h)
|
||||||
|
@ -393,15 +395,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_channels=3,
|
in_channels: int = 3,
|
||||||
out_channels=3,
|
out_channels: int = 3,
|
||||||
down_block_types=("DownEncoderBlock2D",),
|
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
|
||||||
up_block_types=("UpDecoderBlock2D",),
|
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
|
||||||
block_out_channels=(64,),
|
block_out_channels: Tuple[int] = (64,),
|
||||||
layers_per_block=1,
|
layers_per_block: int = 1,
|
||||||
act_fn="silu",
|
act_fn: str = "silu",
|
||||||
latent_channels=4,
|
latent_channels: int = 4,
|
||||||
sample_size=32,
|
sample_size: int = 32,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
@ -440,7 +442,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||||
dec = self.decoder(z)
|
dec = self.decoder(z)
|
||||||
return dec
|
return dec
|
||||||
|
|
||||||
def forward(self, sample, sample_posterior=False):
|
def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
|
||||||
x = sample
|
x = sample
|
||||||
posterior = self.encode(x)
|
posterior = self.encode(x)
|
||||||
if sample_posterior:
|
if sample_posterior:
|
||||||
|
|
Loading…
Reference in New Issue