From 07f8ebd543fa36b2ec5bfcc80d56eadb308b47f5 Mon Sep 17 00:00:00 2001 From: Samuel Ajisegiri Date: Mon, 5 Sep 2022 15:46:12 +0100 Subject: [PATCH] type hints: models/vae.py (#346) * type hints: models/vae.py * modify typings in vae.py Co-authored-by: Patrick von Platen Co-authored-by: Anton Lozhkov --- src/diffusers/models/vae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 2897a3a0..90a1aa00 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -487,7 +487,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - def encode(self, x, return_dict: bool = True): + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments)