Fix Flax pipeline: width and height are ignored #838 (#848)

* Fix Flax pipeline: width and height are ignored #838

* Fix Flax pipeline: width and height are ignored
This commit is contained in:
camenduru 2022-10-14 22:43:56 +03:00 committed by GitHub
parent 1d3234cbca
commit 93a81a3f5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 6 deletions

View File

@ -152,12 +152,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
context = jnp.concatenate([uncond_embeddings, text_embeddings])
latents_shape = (
batch_size,
self.unet.in_channels,
self.unet.sample_size,
self.unet.sample_size,
)
latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else: