Allow dtype to be specified in Flax pipeline (#600)
* Fix typo in docstring. * Allow dtype to be overridden on model load. This may be a temporary solution until #567 is addressed. * Create latents in float32 The denoising loop always computes the next step in float32, so this would fail when using `bfloat16`.
This commit is contained in:
parent
fb03aad8b4
commit
fb2fbab10b
|
@ -154,9 +154,12 @@ class ConfigMixin:
|
|||
|
||||
"""
|
||||
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
# Allow dtype to be specified on initialization
|
||||
if "dtype" in unused_kwargs:
|
||||
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||
|
||||
model = cls(**init_dict)
|
||||
|
||||
if return_unused_kwargs:
|
||||
|
|
|
@ -30,7 +30,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlaxSchedulerMixin`]):
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
|
||||
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
|
||||
|
@ -157,7 +157,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
self.unet.sample_size,
|
||||
)
|
||||
if latents is None:
|
||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
|
||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
|
|
Loading…
Reference in New Issue