* Fix Flax pipeline: width and height are ignored #838 * Fix Flax pipeline: width and height are ignored
This commit is contained in:
parent
1d3234cbca
commit
93a81a3f5a
|
@ -152,12 +152,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.in_channels,
|
||||
self.unet.sample_size,
|
||||
self.unet.sample_size,
|
||||
)
|
||||
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
|
||||
if latents is None:
|
||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue