From 2a8477de5ca73d1b492e7ee5205bfcb1d7497c97 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 19 Sep 2022 16:50:22 +0200 Subject: [PATCH] [Flax] Solve problem with VAE (#574) --- src/diffusers/models/vae_flax.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index e3906c09..793010e8 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -600,7 +600,8 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): hidden_states = posterior.latent_dist.sample(rng) else: hidden_states = posterior.latent_dist.mode() - hidden_states = self.decode(hidden_states, return_dict=return_dict).sample + + sample = self.decode(hidden_states, return_dict=return_dict).sample if not return_dict: return (sample,)