[Flax] Solve problem with VAE (#574)
This commit is contained in:
parent
bf5ca036fa
commit
2a8477de5c
|
@ -600,7 +600,8 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
|||
hidden_states = posterior.latent_dist.sample(rng)
|
||||
else:
|
||||
hidden_states = posterior.latent_dist.mode()
|
||||
hidden_states = self.decode(hidden_states, return_dict=return_dict).sample
|
||||
|
||||
sample = self.decode(hidden_states, return_dict=return_dict).sample
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
|
Loading…
Reference in New Issue