From 7c2262640bbf9fa61c281bc49eb8494cb48da81f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 13 Oct 2022 17:43:06 +0200 Subject: [PATCH] Align PT and Flax API - allow loading checkpoint from PyTorch configs (#827) * up * finish * add more tests * up * up * finish --- src/diffusers/pipeline_flax_utils.py | 55 ++++++---- .../pipeline_flax_stable_diffusion.py | 35 +++++- .../pipeline_stable_diffusion.py | 2 +- .../pipeline_stable_diffusion_img2img.py | 2 +- .../pipeline_stable_diffusion_inpaint.py | 2 +- tests/test_pipelines_flax.py | 100 +++++++++++++++++- 6 files changed, 165 insertions(+), 31 deletions(-) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index b3ac2729..d55338b5 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -111,24 +111,27 @@ class FlaxDiffusionPipeline(ConfigMixin): from diffusers import pipelines for name, module in kwargs.items(): - # retrieve library - library = module.__module__.split(".")[0] + if module is None: + register_dict = {name: (None, None)} + else: + # retrieve library + library = module.__module__.split(".")[0] - # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] - path = module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if library not in LOADABLE_CLASSES or is_pipeline_module: - library = pipeline_dir + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir - # retrieve class_name - class_name = module.__class__.__name__ + # retrieve class_name + class_name = module.__class__.__name__ - register_dict = {name: (library, class_name)} + register_dict = {name: (library, class_name)} # save model index config self.register_to_config(**register_dict) @@ -320,6 +323,11 @@ class FlaxDiffusionPipeline(ConfigMixin): pipeline_class = cls else: diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + class_name = ( + config_dict["_class_name"] + if config_dict["_class_name"].startswith("Flax") + else "Flax" + config_dict["_class_name"] + ) pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) # some modules can be passed directly to the init @@ -342,6 +350,7 @@ class FlaxDiffusionPipeline(ConfigMixin): for name, (library_name, class_name) in init_dict.items(): is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None + sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: @@ -362,6 +371,12 @@ class FlaxDiffusionPipeline(ConfigMixin): f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) + elif passed_class_obj[name] is None: + logger.warn( + f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" + f" that this might lead to problems when using {pipeline_class} and is not recommended." + ) + sub_model_should_be_defined = False else: logger.warn( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" @@ -372,25 +387,19 @@ class FlaxDiffusionPipeline(ConfigMixin): loaded_sub_model = passed_class_obj[name] elif is_pipeline_module: pipeline_module = getattr(pipelines, library_name) - if from_pt: - class_obj = import_flax_or_no_model(pipeline_module, class_name) - else: - class_obj = getattr(pipeline_module, class_name) + class_obj = import_flax_or_no_model(pipeline_module, class_name) importable_classes = ALL_IMPORTABLE_CLASSES class_candidates = {c: class_obj for c in importable_classes.keys()} else: # else we just import it from the library. library = importlib.import_module(library_name) - if from_pt: - class_obj = import_flax_or_no_model(library, class_name) - else: - class_obj = getattr(library, class_name) + class_obj = import_flax_or_no_model(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} - if loaded_sub_model is None: + if loaded_sub_model is None and sub_model_should_be_defined: load_method_name = None for class_name, class_candidate in class_candidates.items(): if issubclass(class_obj, class_candidate): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 1f7607fa..6cd67829 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -14,10 +14,14 @@ from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...pipeline_flax_utils import FlaxDiffusionPipeline from ...schedulers import FlaxDDIMScheduler, FlaxLMSDiscreteScheduler, FlaxPNDMScheduler +from ...utils import logging from . import FlaxStableDiffusionPipelineOutput from .safety_checker_flax import FlaxStableDiffusionSafetyChecker +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -60,6 +64,16 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): super().__init__() self.dtype = dtype + if safety_checker is None: + logger.warn( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -265,10 +279,23 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug ) - safety_params = params["safety_checker"] - images = (images * 255).round().astype("uint8") - images = np.asarray(images).reshape(-1, height, width, 3) - images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit) + if self.safety_checker is not None: + safety_params = params["safety_checker"] + images_uint8_casted = (images * 255).round().astype("uint8") + num_devices, batch_size = images.shape[:2] + + images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3) + images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit) + images = np.asarray(images) + + # block images + if any(has_nsfw_concept): + for i, is_nsfw in enumerate(has_nsfw_concept): + images[i] = np.asarray(images_uint8_casted[i]) + + images = images.reshape(num_devices, batch_size, height, width, 3) + else: + has_nsfw_concept = False if not return_dict: return (images, has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index bf8ba349..8ae51999 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -73,7 +73,7 @@ class StableDiffusionPipeline(DiffusionPipeline): if safety_checker is None: logger.warn( - f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index a9507450..799fd459 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -85,7 +85,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): if safety_checker is None: logger.warn( - f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1e4a5e77..cb4d552a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -100,7 +100,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): if safety_checker is None: logger.warn( - f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" " results in services or applications open to the public. Both the diffusers team and Hugging Face" " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" diff --git a/tests/test_pipelines_flax.py b/tests/test_pipelines_flax.py index 18dce171..bcf71dcc 100644 --- a/tests/test_pipelines_flax.py +++ b/tests/test_pipelines_flax.py @@ -23,6 +23,7 @@ from diffusers.utils.testing_utils import require_flax, slow if is_flax_available(): import jax + import jax.numpy as jnp from diffusers import FlaxStableDiffusionPipeline from flax.jax_utils import replicate from flax.training.common_utils import shard @@ -34,7 +35,7 @@ if is_flax_available(): class FlaxPipelineTests(unittest.TestCase): def test_dummy_all_tpus(self): pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - "hf-internal-testing/tiny-stable-diffusion-pipe" + "hf-internal-testing/tiny-stable-diffusion-pipe", safety_checker=None ) prompt = ( @@ -57,6 +58,103 @@ class FlaxPipelineTests(unittest.TestCase): prompt_ids = shard(prompt_ids) images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + + assert images.shape == (8, 1, 64, 64, 3) + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 4.151474)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 49947.875)) < 1e-2 + images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) assert len(images_pil) == 8 + + def test_stable_diffusion_v1_4(self): + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="flax", safety_checker=None + ) + + prompt = ( + "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" + " field, close up, split lighting, cinematic" + ) + + prng_seed = jax.random.PRNGKey(0) + num_inference_steps = 50 + + num_samples = jax.device_count() + prompt = num_samples * [prompt] + prompt_ids = pipeline.prepare_inputs(prompt) + + p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) + + # shard inputs and rng + params = replicate(params) + prng_seed = jax.random.split(prng_seed, 8) + prompt_ids = shard(prompt_ids) + + images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + + images_pil = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:]))) + for i, image in enumerate(images_pil): + image.save(f"/home/patrick/images/flax-test-{i}_fp32.png") + + assert images.shape == (8, 1, 512, 512, 3) + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.05652401)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2383808.2)) < 1e-2 + + def test_stable_diffusion_v1_4_bfloat_16(self): + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None + ) + + prompt = ( + "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" + " field, close up, split lighting, cinematic" + ) + + prng_seed = jax.random.PRNGKey(0) + num_inference_steps = 50 + + num_samples = jax.device_count() + prompt = num_samples * [prompt] + prompt_ids = pipeline.prepare_inputs(prompt) + + p_sample = pmap(pipeline.__call__, static_broadcasted_argnums=(3,)) + + # shard inputs and rng + params = replicate(params) + prng_seed = jax.random.split(prng_seed, 8) + prompt_ids = shard(prompt_ids) + + images = p_sample(prompt_ids, params, prng_seed, num_inference_steps).images + + assert images.shape == (8, 1, 512, 512, 3) + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2 + + def test_stable_diffusion_v1_4_bfloat_16_with_safety(self): + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16 + ) + + prompt = ( + "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of" + " field, close up, split lighting, cinematic" + ) + + prng_seed = jax.random.PRNGKey(0) + num_inference_steps = 50 + + num_samples = jax.device_count() + prompt = num_samples * [prompt] + prompt_ids = pipeline.prepare_inputs(prompt) + + # shard inputs and rng + params = replicate(params) + prng_seed = jax.random.split(prng_seed, 8) + prompt_ids = shard(prompt_ids) + + images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images + + assert images.shape == (8, 1, 512, 512, 3) + assert np.abs((np.abs(images[0, 0, :2, :2, -2:], dtype=np.float32).sum() - 0.06652832)) < 1e-3 + assert np.abs((np.abs(images, dtype=np.float32).sum() - 2384849.8)) < 1e-2