[Flax] Solve problem with VAE (#574)

This commit is contained in:
Patrick von Platen 2022-09-19 16:50:22 +02:00 committed by GitHub
parent bf5ca036fa
commit 2a8477de5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -600,7 +600,8 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
hidden_states = posterior.latent_dist.sample(rng)
else:
hidden_states = posterior.latent_dist.mode()
hidden_states = self.decode(hidden_states, return_dict=return_dict).sample
sample = self.decode(hidden_states, return_dict=return_dict).sample
if not return_dict:
return (sample,)