diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index 2897a3a0..90a1aa00 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -487,7 +487,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin): self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - def encode(self, x, return_dict: bool = True): + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments)