[Type Hints] VAE models (#344)

* [Type Hints] VAE models

* apply suggestions from code review

apply suggestions to also return the return type
This commit is contained in:
Partho 2022-09-04 21:36:16 +05:30 committed by GitHub
parent 878af0e113
commit 5791f4acde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 24 additions and 22 deletions

View File

@ -1,3 +1,5 @@
from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -293,7 +295,7 @@ class DiagonalGaussianDistribution(object):
if self.deterministic: if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self, generator=None): def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device) x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
return x return x
@ -327,16 +329,16 @@ class VQModel(ModelMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
in_channels=3, in_channels: int = 3,
out_channels=3, out_channels: int = 3,
down_block_types=("DownEncoderBlock2D",), down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",), up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels=(64,), block_out_channels: Tuple[int] = (64,),
layers_per_block=1, layers_per_block: int = 1,
act_fn="silu", act_fn: str = "silu",
latent_channels=3, latent_channels: int = 3,
sample_size=32, sample_size: int = 32,
num_vq_embeddings=256, num_vq_embeddings: int = 256,
): ):
super().__init__() super().__init__()
@ -382,7 +384,7 @@ class VQModel(ModelMixin, ConfigMixin):
dec = self.decoder(quant) dec = self.decoder(quant)
return dec return dec
def forward(self, sample): def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
x = sample x = sample
h = self.encode(x) h = self.encode(x)
dec = self.decode(h) dec = self.decode(h)
@ -393,15 +395,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
@register_to_config @register_to_config
def __init__( def __init__(
self, self,
in_channels=3, in_channels: int = 3,
out_channels=3, out_channels: int = 3,
down_block_types=("DownEncoderBlock2D",), down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",), up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels=(64,), block_out_channels: Tuple[int] = (64,),
layers_per_block=1, layers_per_block: int = 1,
act_fn="silu", act_fn: str = "silu",
latent_channels=4, latent_channels: int = 4,
sample_size=32, sample_size: int = 32,
): ):
super().__init__() super().__init__()
@ -440,7 +442,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
dec = self.decoder(z) dec = self.decoder(z)
return dec return dec
def forward(self, sample, sample_posterior=False): def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
x = sample x = sample
posterior = self.encode(x) posterior = self.encode(x)
if sample_posterior: if sample_posterior: