type hints: models/vae.py (#346)
* type hints: models/vae.py * modify typings in vae.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Anton Lozhkov <anton@huggingface.co>
This commit is contained in:
parent
ada09bd3f0
commit
07f8ebd543
|
@ -487,7 +487,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
|||
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
|
||||
def encode(self, x, return_dict: bool = True):
|
||||
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
|
|
Loading…
Reference in New Issue