Flax safety checker (#825)

* Remove set_format in Flax pipeline.

* Remove DummyChecker.

* Run safety_checker in pipeline.

* Don't pmap on every call.

We could have decorated `generate` with `pmap`, but I wanted to keep it
in case someone wants to invoke it in non-parallel mode.

* Remove commented line

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Replicate outside __call__, prepare for optional jitting.

* Remove unnecessary clipping.

As suggested by @kashif.

* Do not jit unless requested.

* Send all args to generate.

* make style

* Remove unused imports.

* Fix docstring.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Pedro Cuenca 2022-10-13 17:01:47 +02:00 committed by GitHub
parent e713346ad1
commit 78db11dbf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 135 additions and 78 deletions

View File

@ -62,11 +62,6 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
class DummyChecker:
def __init__(self):
self.dummy = True
def import_flax_or_no_model(module, class_name):
try:
# 1. First make sure that if a Flax object is present, import this one
@ -177,10 +172,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
if save_method_name is not None:
break
# TODO(Patrick, Suraj): to delete after
if isinstance(sub_model, DummyChecker):
continue
save_method = getattr(sub_model, save_method_name)
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
@ -194,7 +185,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
@ -349,11 +340,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# TODO(Patrick, Suraj) - delete later
if class_name == "DummyChecker":
library_name = "stable_diffusion"
class_name = "FlaxStableDiffusionSafetyChecker"
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
@ -422,11 +408,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
params[name] = loaded_params
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
# make sure we don't initialize the weights to save time
if name == "safety_checker":
loaded_sub_model = DummyChecker()
loaded_params = {}
elif from_pt:
if from_pt:
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
loaded_params = loaded_sub_model.params

View File

@ -1,8 +1,14 @@
from functools import partial
from typing import Dict, List, Optional, Union
import numpy as np
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from flax.jax_utils import unreplicate
from flax.training.common_utils import shard
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
@ -77,60 +83,44 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
)
return text_input.input_ids
def __call__(
def _get_safety_scores(self, features, params):
special_cos_dist, cos_dist = self.safety_checker(features, params)
return (special_cos_dist, cos_dist)
def _run_safety_checker(self, images, safety_model_params, jit=False):
# safety_model_params should already be replicated when jit is True
pil_images = [Image.fromarray(image) for image in images]
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
if jit:
features = shard(features)
special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
special_cos_dist = unshard(special_cos_dist)
cos_dist = unshard(cos_dist)
safety_model_params = unreplicate(safety_model_params)
else:
special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
images, has_nsfw = self.safety_checker.filtered_with_scores(
special_cos_dist,
cos_dist,
images,
safety_model_params,
)
return images, has_nsfw
def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: Optional[int] = 50,
height: Optional[int] = 512,
width: Optional[int] = 512,
guidance_scale: Optional[float] = 7.5,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
latents: Optional[jnp.array] = None,
return_dict: bool = True,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@ -203,21 +193,106 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
# TODO: check when flax vae gets merged into main
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
# run safety checker
# TODO: check when flax safety checker gets merged into main
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
# image, has_nsfw_concept = self.safety_checker(
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
# )
has_nsfw_concept = False
def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.PRNGKey,
num_inference_steps: int = 50,
height: int = 512,
width: int = 512,
guidance_scale: float = 7.5,
latents: jnp.array = None,
return_dict: bool = True,
jit: bool = False,
debug: bool = False,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`jnp.array`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
a plain tuple.
Returns:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
element is a list of `bool`s denoting whether the corresponding generated image likely represents
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
"""
if jit:
images = _p_generate(
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
else:
images = self._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
safety_params = params["safety_checker"]
images = (images * 255).round().astype("uint8")
images = np.asarray(images).reshape(-1, height, width, 3)
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
if not return_dict:
return (image, has_nsfw_concept)
return (images, has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
):
return pipe._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)
@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_safety_scores(pipe, features, params):
return pipe._get_safety_scores(features, params)
def unshard(x: jnp.ndarray):
# einops.rearrange(x, 'd b ... -> (d b) ...')
num_devices, batch_size = x.shape[:2]
rest = x.shape[2:]
return x.reshape(num_devices * batch_size, *rest)