diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 7e58d048..18c008f8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -152,12 +152,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) - latents_shape = ( - batch_size, - self.unet.in_channels, - self.unet.sample_size, - self.unet.sample_size, - ) + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: