Flax safety checker (#825)

* Remove set_format in Flax pipeline.

* Remove DummyChecker.

* Run safety_checker in pipeline.

* Don't pmap on every call.

We could have decorated `generate` with `pmap`, but I wanted to keep it
undecorated in case someone wants to invoke it in non-parallel mode.

* Remove commented line

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Replicate outside __call__, prepare for optional jitting.

* Remove unnecessary clipping.

As suggested by @kashif.

* Do not jit unless requested.

* Send all args to generate.

* make style

* Remove unused imports.

* Fix docstring.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Pedro Cuenca 2022-10-13 17:01:47 +02:00 committed by GitHub
parent e713346ad1
commit 78db11dbf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 135 additions and 78 deletions

View File

@ -62,11 +62,6 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
class DummyChecker:
def __init__(self):
self.dummy = True
def import_flax_or_no_model(module, class_name): def import_flax_or_no_model(module, class_name):
try: try:
# 1. First make sure that if a Flax object is present, import this one # 1. First make sure that if a Flax object is present, import this one
@ -177,10 +172,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
if save_method_name is not None: if save_method_name is not None:
break break
# TODO(Patrick, Suraj): to delete after
if isinstance(sub_model, DummyChecker):
continue
save_method = getattr(sub_model, save_method_name) save_method = getattr(sub_model, save_method_name)
expects_params = "params" in set(inspect.signature(save_method).parameters.keys()) expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
@ -194,7 +185,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r""" r"""
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights. Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
@ -349,11 +340,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline # 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items(): for name, (library_name, class_name) in init_dict.items():
# TODO(Patrick, Suraj) - delete later
if class_name == "DummyChecker":
library_name = "stable_diffusion"
class_name = "FlaxStableDiffusionSafetyChecker"
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None loaded_sub_model = None
@ -422,11 +408,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype) loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
params[name] = loaded_params params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel): elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
# make sure we don't initialize the weights to save time if from_pt:
if name == "safety_checker":
loaded_sub_model = DummyChecker()
loaded_params = {}
elif from_pt:
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here # TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt) loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
loaded_params = loaded_sub_model.params loaded_params = loaded_sub_model.params

View File

@ -1,8 +1,14 @@
from functools import partial
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import numpy as np
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
@ -77,60 +83,44 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
) )
return text_input.input_ids return text_input.input_ids
def __call__( def _get_safety_scores(self, features, params):
special_cos_dist, cos_dist = self.safety_checker(features, params)
return (special_cos_dist, cos_dist)
def _run_safety_checker(self, images, safety_model_params, jit=False):
    """Score `images` with the safety checker and filter them by those scores.

    Each entry of `images` is converted to a PIL image and passed through
    `self.feature_extractor` (``return_tensors="np"``); the resulting
    `pixel_values` are scored either directly via `self._get_safety_scores`
    or, when `jit` is True, via the pmapped `_p_get_safety_scores`.

    Args:
        images: Iterable of arrays accepted by `PIL.Image.fromarray`.
        safety_model_params: Parameters for the safety checker. They should
            already be replicated when `jit` is True.
        jit: Whether to use the pmapped scoring path.

    Returns:
        The two values produced by `self.safety_checker.filtered_with_scores`
        (presumably the filtered images and the per-image NSFW flags —
        confirm against the safety checker implementation).
    """
    # safety_model_params should already be replicated when jit is True
    clip_input = self.feature_extractor(
        [Image.fromarray(img) for img in images],
        return_tensors="np",
    ).pixel_values

    if not jit:
        special_scores, concept_scores = self._get_safety_scores(clip_input, safety_model_params)
    else:
        # Split the features across devices, score in parallel, then merge
        # the device axis back into the batch axis.
        sharded_input = shard(clip_input)
        special_scores, concept_scores = _p_get_safety_scores(self, sharded_input, safety_model_params)
        special_scores = unshard(special_scores)
        concept_scores = unshard(concept_scores)
        # The filtering step below runs un-pmapped, so it gets a single
        # (unreplicated) copy of the parameters.
        safety_model_params = unreplicate(safety_model_params)

    filtered_images, nsfw_flags = self.safety_checker.filtered_with_scores(
        special_scores,
        concept_scores,
        images,
        safety_model_params,
    )
    return filtered_images, nsfw_flags
def _generate(
self, self,
prompt_ids: jnp.array, prompt_ids: jnp.array,
params: Union[Dict, FrozenDict], params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey, prng_seed: jax.random.PRNGKey,
num_inference_steps: Optional[int] = 50, num_inference_steps: int = 50,
height: Optional[int] = 512, height: int = 512,
width: Optional[int] = 512, width: int = 512,
guidance_scale: Optional[float] = 7.5, guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None, latents: Optional[jnp.array] = None,
return_dict: bool = True,
debug: bool = False, debug: bool = False,
**kwargs,
): ):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if height % 8 != 0 or width % 8 != 0: if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@ -203,21 +193,106 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# scale and decode the image latents with vae # scale and decode the image latents with vae
latents = 1 / 0.18215 * latents latents = 1 / 0.18215 * latents
# TODO: check when flax vae gets merged into main
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1) image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image
# image = jnp.asarray(image).transpose(0, 2, 3, 1) def __call__(
# run safety checker self,
# TODO: check when flax safety checker gets merged into main prompt_ids: jnp.array,
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") params: Union[Dict, FrozenDict],
# image, has_nsfw_concept = self.safety_checker( prng_seed: jax.random.PRNGKey,
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"] num_inference_steps: int = 50,
# ) height: int = 512,
has_nsfw_concept = False width: int = 512,
guidance_scale: float = 7.5,
latents: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if jit:
images = _p_generate(
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
else:
images = self._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
safety_params = params["safety_checker"]
images = (images * 255).round().astype("uint8")
images = np.asarray(images).reshape(-1, height, width, 3)
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (images, has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
):
return pipe._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_safety_scores(pipe, features, params):
return pipe._get_safety_scores(features, params)
def unshard(x: jnp.ndarray):
    """Merge the first two axes of `x` into a single leading axis.

    The first axis is treated as the device axis and the second as the
    per-device batch axis; every axis after those is kept as is.
    """
    # einops.rearrange(x, 'd b ... -> (d b) ...')
    device_count, per_device_batch = x.shape[:2]
    merged_shape = (device_count * per_device_batch,) + tuple(x.shape[2:])
    return x.reshape(merged_shape)