From 93a81a3f5a7ee9eb1ee520010bec86e64ffc901e Mon Sep 17 00:00:00 2001 From: camenduru <54370274+camenduru@users.noreply.github.com> Date: Fri, 14 Oct 2022 22:43:56 +0300 Subject: [PATCH] Fix Flax pipeline: width and height are ignored #838 (#848) * Fix Flax pipeline: width and height are ignored #838 * Fix Flax pipeline: width and height are ignored --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 7e58d048..18c008f8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -152,12 +152,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) - latents_shape = ( - batch_size, - self.unet.in_channels, - self.unet.sample_size, - self.unet.sample_size, - ) + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) if latents is None: latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32) else: