Allow dtype to be specified in Flax pipeline (#600)
* Fix typo in docstring.
* Allow dtype to be overridden on model load. This may be a temporary solution until #567 is addressed.
* Create latents in float32. The denoising loop always computes the next step in float32, so this would fail when using `bfloat16`.
This commit is contained in:
parent
fb03aad8b4
commit
fb2fbab10b
|
@@ -154,9 +154,12 @@ class ConfigMixin:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
|
config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
|
||||||
|
|
||||||
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
# Allow dtype to be specified on initialization
|
||||||
|
if "dtype" in unused_kwargs:
|
||||||
|
init_dict["dtype"] = unused_kwargs.pop("dtype")
|
||||||
|
|
||||||
model = cls(**init_dict)
|
model = cls(**init_dict)
|
||||||
|
|
||||||
if return_unused_kwargs:
|
if return_unused_kwargs:
|
||||||
|
|
|
@@ -30,7 +30,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||||
Tokenizer of class
|
Tokenizer of class
|
||||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||||
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
unet ([`FlaxUNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
|
||||||
scheduler ([`FlaxSchedulerMixin`]):
|
scheduler ([`SchedulerMixin`]):
|
||||||
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
|
||||||
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
|
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
|
||||||
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
|
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
|
||||||
|
@@ -157,7 +157,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||||
self.unet.sample_size,
|
self.unet.sample_size,
|
||||||
)
|
)
|
||||||
if latents is None:
|
if latents is None:
|
||||||
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=self.dtype)
|
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
|
||||||
else:
|
else:
|
||||||
if latents.shape != latents_shape:
|
if latents.shape != latents_shape:
|
||||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||||
|
|
Loading…
Reference in New Issue