type hints: models/vae.py (#346)
* type hints: models/vae.py * modify typings in vae.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
parent
ada09bd3f0
commit
07f8ebd543
|
@ -487,7 +487,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
||||||
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||||
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||||
|
|
||||||
def encode(self, x, return_dict: bool = True):
|
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||||
h = self.encoder(x)
|
h = self.encoder(x)
|
||||||
moments = self.quant_conv(h)
|
moments = self.quant_conv(h)
|
||||||
posterior = DiagonalGaussianDistribution(moments)
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
|
|
Loading…
Reference in New Issue