diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index e3906c09..793010e8 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -600,7 +600,8 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): hidden_states = posterior.latent_dist.sample(rng) else: hidden_states = posterior.latent_dist.mode() - hidden_states = self.decode(hidden_states, return_dict=return_dict).sample + + sample = self.decode(hidden_states, return_dict=return_dict).sample if not return_dict: return (sample,)