Allow dtype to be specified in Flax pipeline (#600)

* Fix typo in docstring.

* Allow dtype to be overridden on model load.

This may be a temporary solution until #567 is addressed.

* Create latents in float32

The denoising loop always computes the next step in float32, so this
would fail when using `bfloat16`.
This commit is contained in:
Pedro Cuenca 2022-09-21 10:57:01 +02:00 committed by GitHub
parent fb03aad8b4
commit fb2fbab10b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 3 deletions

View File

@ -154,9 +154,12 @@ class ConfigMixin:
"""
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
# Allow dtype to be specified on initialization
if "dtype" in unused_kwargs:
init_dict["dtype"] = unused_kwargs.pop("dtype")
model = cls(**init_dict)
if return_unused_kwargs:

View File

@ -30,7 +30,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`FlaxSchedulerMixin`]):
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
@ -157,7 +157,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
self.unet.sample_size,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")