[Flax] Solve problem with VAE (#574)
This commit is contained in:
parent
bf5ca036fa
commit
2a8477de5c
|
@ -600,7 +600,8 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||||
hidden_states = posterior.latent_dist.sample(rng)
|
hidden_states = posterior.latent_dist.sample(rng)
|
||||||
else:
|
else:
|
||||||
hidden_states = posterior.latent_dist.mode()
|
hidden_states = posterior.latent_dist.mode()
|
||||||
hidden_states = self.decode(hidden_states, return_dict=return_dict).sample
|
|
||||||
|
sample = self.decode(hidden_states, return_dict=return_dict).sample
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (sample,)
|
return (sample,)
|
||||||
|
|
Loading…
Reference in New Issue