From 78db11dbf3f90150203316e60b2bf633982945aa Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Thu, 13 Oct 2022 17:01:47 +0200 Subject: [PATCH] Flax safety checker (#825) * Remove set_format in Flax pipeline. * Remove DummyChecker. * Run safety_checker in pipeline. * Don't pmap on every call. We could have decorated `generate` with `pmap`, but I wanted to keep it in case someone wants to invoke it in non-parallel mode. * Remove commented line Co-authored-by: Patrick von Platen * Replicate outside __call__, prepare for optional jitting. * Remove unnecessary clipping. As suggested by @kashif. * Do not jit unless requested. * Send all args to generate. * make style * Remove unused imports. * Fix docstring. Co-authored-by: Patrick von Platen --- src/diffusers/pipeline_flax_utils.py | 22 +- .../pipeline_flax_stable_diffusion.py | 191 ++++++++++++------ 2 files changed, 135 insertions(+), 78 deletions(-) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 92b71cae..b3ac2729 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -62,11 +62,6 @@ for library in LOADABLE_CLASSES: ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) -class DummyChecker: - def __init__(self): - self.dummy = True - - def import_flax_or_no_model(module, class_name): try: # 1. 
First make sure that if a Flax object is present, import this one @@ -177,10 +172,6 @@ class FlaxDiffusionPipeline(ConfigMixin): if save_method_name is not None: break - # TODO(Patrick, Suraj): to delete after - if isinstance(sub_model, DummyChecker): - continue - save_method = getattr(sub_model, save_method_name) expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) @@ -194,7 +185,7 @@ class FlaxDiffusionPipeline(ConfigMixin): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" - Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. + Instantiate a Flax diffusion pipeline from pre-trained pipeline weights. The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). @@ -349,11 +340,6 @@ class FlaxDiffusionPipeline(ConfigMixin): # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): - # TODO(Patrick, Suraj) - delete later - if class_name == "DummyChecker": - library_name = "stable_diffusion" - class_name = "FlaxStableDiffusionSafetyChecker" - is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None @@ -422,11 +408,7 @@ class FlaxDiffusionPipeline(ConfigMixin): loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype) params[name] = loaded_params elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): - # make sure we don't initialize the weights to save time - if name == "safety_checker": - loaded_sub_model = DummyChecker() - loaded_params = {} - elif from_pt: + if from_pt: # TODO(Suraj): Fix this in Transformers. 
We should be able to use `_do_init=False` here loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) loaded_params = loaded_sub_model.params diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 81c6bd80..1f7607fa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -1,8 +1,14 @@ +from functools import partial from typing import Dict, List, Optional, Union +import numpy as np + import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict +from flax.jax_utils import unreplicate +from flax.training.common_utils import shard +from PIL import Image from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel @@ -77,60 +83,44 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): ) return text_input.input_ids - def __call__( + def _get_safety_scores(self, features, params): + special_cos_dist, cos_dist = self.safety_checker(features, params) + return (special_cos_dist, cos_dist) + + def _run_safety_checker(self, images, safety_model_params, jit=False): + # safety_model_params should already be replicated when jit is True + pil_images = [Image.fromarray(image) for image in images] + features = self.feature_extractor(pil_images, return_tensors="np").pixel_values + + if jit: + features = shard(features) + special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params) + special_cos_dist = unshard(special_cos_dist) + cos_dist = unshard(cos_dist) + safety_model_params = unreplicate(safety_model_params) + else: + special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params) + + images, has_nsfw = self.safety_checker.filtered_with_scores( + special_cos_dist, + cos_dist, + images, + 
safety_model_params, + ) + return images, has_nsfw + + def _generate( self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], prng_seed: jax.random.PRNGKey, - num_inference_steps: Optional[int] = 50, - height: Optional[int] = 512, - width: Optional[int] = 512, - guidance_scale: Optional[float] = 7.5, + num_inference_steps: int = 50, + height: int = 512, + width: int = 512, + guidance_scale: float = 7.5, latents: Optional[jnp.array] = None, - return_dict: bool = True, debug: bool = False, - **kwargs, ): - r""" - Function invoked when calling the pipeline for generation. - - Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - height (`int`, *optional*, defaults to 512): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to 512): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - guidance_scale (`float`, *optional*, defaults to 7.5): - Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). - `guidance_scale` is defined as `w` of equation 2. of [Imagen - Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > - 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, - usually at the expense of lower image quality. - generator (`torch.Generator`, *optional*): - A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation - deterministic. - latents (`jnp.array`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. 
If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. - output_type (`str`, *optional*, defaults to `"pil"`): - The output format of the generate image. Choose between - [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of - a plain tuple. - - Returns: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: - [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a - `tuple. When returning a tuple, the first element is a list with the generated images, and the second - element is a list of `bool`s denoting whether the corresponding generated image likely represents - "not-safe-for-work" (nsfw) content, according to the `safety_checker`. - """ if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -203,21 +193,106 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - # TODO: check when flax vae gets merged into main image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) + return image - # image = jnp.asarray(image).transpose(0, 2, 3, 1) - # run safety checker - # TODO: check when flax safety checker gets merged into main - # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") - # image, has_nsfw_concept = self.safety_checker( - # images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"] - # ) - has_nsfw_concept = False + def __call__( + self, + prompt_ids: jnp.array, + params: Union[Dict, FrozenDict], + prng_seed: 
jax.random.PRNGKey, + num_inference_steps: int = 50, + height: int = 512, + width: int = 512, + guidance_scale: float = 7.5, + latents: jnp.array = None, + return_dict: bool = True, + jit: bool = False, + debug: bool = False, + **kwargs, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`): + The prompt or prompts to guide the image generation. + height (`int`, *optional*, defaults to 512): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to 512): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + generator (`torch.Generator`, *optional*): + A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation + deterministic. + latents (`jnp.array`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
+ jit (`bool`, defaults to `False`): + Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument + exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of + a plain tuple. + + Returns: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images, and the second + element is a list of `bool`s denoting whether the corresponding generated image likely represents + "not-safe-for-work" (nsfw) content, according to the `safety_checker`. + """ + if jit: + images = _p_generate( + self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + ) + else: + images = self._generate( + prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + ) + + safety_params = params["safety_checker"] + images = (images * 255).round().astype("uint8") + images = np.asarray(images).reshape(-1, height, width, 3) + images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit) if not return_dict: - return (image, has_nsfw_concept) + return (images, has_nsfw_concept) - return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) + + +# TODO: maybe use a config dict instead of so many static argnums +@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9)) +def _p_generate( + pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug +): + return pipe._generate( + 
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + ) + + +@partial(jax.pmap, static_broadcasted_argnums=(0,)) +def _p_get_safety_scores(pipe, features, params): + return pipe._get_safety_scores(features, params) + + +def unshard(x: jnp.ndarray): + # einops.rearrange(x, 'd b ... -> (d b) ...') + num_devices, batch_size = x.shape[:2] + rest = x.shape[2:] + return x.reshape(num_devices * batch_size, *rest)