Flax safety checker (#825)
* Remove set_format in Flax pipeline. * Remove DummyChecker. * Run safety_checker in pipeline. * Don't pmap on every call. We could have decorated `generate` with `pmap`, but I wanted to keep it in case someone wants to invoke it in non-parallel mode. * Remove commented line Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Replicate outside __call__, prepare for optional jitting. * Remove unnecessary clipping. As suggested by @kashif. * Do not jit unless requested. * Send all args to generate. * make style * Remove unused imports. * Fix docstring. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
e713346ad1
commit
78db11dbf3
|
@ -62,11 +62,6 @@ for library in LOADABLE_CLASSES:
|
||||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||||
|
|
||||||
|
|
||||||
class DummyChecker:
|
|
||||||
def __init__(self):
|
|
||||||
self.dummy = True
|
|
||||||
|
|
||||||
|
|
||||||
def import_flax_or_no_model(module, class_name):
|
def import_flax_or_no_model(module, class_name):
|
||||||
try:
|
try:
|
||||||
# 1. First make sure that if a Flax object is present, import this one
|
# 1. First make sure that if a Flax object is present, import this one
|
||||||
|
@ -177,10 +172,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||||
if save_method_name is not None:
|
if save_method_name is not None:
|
||||||
break
|
break
|
||||||
|
|
||||||
# TODO(Patrick, Suraj): to delete after
|
|
||||||
if isinstance(sub_model, DummyChecker):
|
|
||||||
continue
|
|
||||||
|
|
||||||
save_method = getattr(sub_model, save_method_name)
|
save_method = getattr(sub_model, save_method_name)
|
||||||
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
|
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
|
||||||
|
|
||||||
|
@ -194,7 +185,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
|
Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
|
||||||
|
|
||||||
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
|
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
|
||||||
|
|
||||||
|
@ -349,11 +340,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
# 3. Load each module in the pipeline
|
# 3. Load each module in the pipeline
|
||||||
for name, (library_name, class_name) in init_dict.items():
|
for name, (library_name, class_name) in init_dict.items():
|
||||||
# TODO(Patrick, Suraj) - delete later
|
|
||||||
if class_name == "DummyChecker":
|
|
||||||
library_name = "stable_diffusion"
|
|
||||||
class_name = "FlaxStableDiffusionSafetyChecker"
|
|
||||||
|
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
loaded_sub_model = None
|
loaded_sub_model = None
|
||||||
|
|
||||||
|
@ -422,11 +408,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||||
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
|
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
|
||||||
params[name] = loaded_params
|
params[name] = loaded_params
|
||||||
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
|
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
|
||||||
# make sure we don't initialize the weights to save time
|
if from_pt:
|
||||||
if name == "safety_checker":
|
|
||||||
loaded_sub_model = DummyChecker()
|
|
||||||
loaded_params = {}
|
|
||||||
elif from_pt:
|
|
||||||
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
|
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
|
||||||
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
|
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
|
||||||
loaded_params = loaded_sub_model.params
|
loaded_params = loaded_sub_model.params
|
||||||
|
|
|
@ -1,8 +1,14 @@
|
||||||
|
from functools import partial
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from flax.core.frozen_dict import FrozenDict
|
from flax.core.frozen_dict import FrozenDict
|
||||||
|
from flax.jax_utils import unreplicate
|
||||||
|
from flax.training.common_utils import shard
|
||||||
|
from PIL import Image
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||||
|
|
||||||
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
||||||
|
@ -77,60 +83,44 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||||
)
|
)
|
||||||
return text_input.input_ids
|
return text_input.input_ids
|
||||||
|
|
||||||
def __call__(
|
def _get_safety_scores(self, features, params):
|
||||||
|
special_cos_dist, cos_dist = self.safety_checker(features, params)
|
||||||
|
return (special_cos_dist, cos_dist)
|
||||||
|
|
||||||
|
def _run_safety_checker(self, images, safety_model_params, jit=False):
|
||||||
|
# safety_model_params should already be replicated when jit is True
|
||||||
|
pil_images = [Image.fromarray(image) for image in images]
|
||||||
|
features = self.feature_extractor(pil_images, return_tensors="np").pixel_values
|
||||||
|
|
||||||
|
if jit:
|
||||||
|
features = shard(features)
|
||||||
|
special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
|
||||||
|
special_cos_dist = unshard(special_cos_dist)
|
||||||
|
cos_dist = unshard(cos_dist)
|
||||||
|
safety_model_params = unreplicate(safety_model_params)
|
||||||
|
else:
|
||||||
|
special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
|
||||||
|
|
||||||
|
images, has_nsfw = self.safety_checker.filtered_with_scores(
|
||||||
|
special_cos_dist,
|
||||||
|
cos_dist,
|
||||||
|
images,
|
||||||
|
safety_model_params,
|
||||||
|
)
|
||||||
|
return images, has_nsfw
|
||||||
|
|
||||||
|
def _generate(
|
||||||
self,
|
self,
|
||||||
prompt_ids: jnp.array,
|
prompt_ids: jnp.array,
|
||||||
params: Union[Dict, FrozenDict],
|
params: Union[Dict, FrozenDict],
|
||||||
prng_seed: jax.random.PRNGKey,
|
prng_seed: jax.random.PRNGKey,
|
||||||
num_inference_steps: Optional[int] = 50,
|
num_inference_steps: int = 50,
|
||||||
height: Optional[int] = 512,
|
height: int = 512,
|
||||||
width: Optional[int] = 512,
|
width: int = 512,
|
||||||
guidance_scale: Optional[float] = 7.5,
|
guidance_scale: float = 7.5,
|
||||||
latents: Optional[jnp.array] = None,
|
latents: Optional[jnp.array] = None,
|
||||||
return_dict: bool = True,
|
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
**kwargs,
|
|
||||||
):
|
):
|
||||||
r"""
|
|
||||||
Function invoked when calling the pipeline for generation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (`str` or `List[str]`):
|
|
||||||
The prompt or prompts to guide the image generation.
|
|
||||||
height (`int`, *optional*, defaults to 512):
|
|
||||||
The height in pixels of the generated image.
|
|
||||||
width (`int`, *optional*, defaults to 512):
|
|
||||||
The width in pixels of the generated image.
|
|
||||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
|
||||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
|
||||||
expense of slower inference.
|
|
||||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
|
||||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
|
||||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
|
||||||
usually at the expense of lower image quality.
|
|
||||||
generator (`torch.Generator`, *optional*):
|
|
||||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
|
||||||
deterministic.
|
|
||||||
latents (`jnp.array`, *optional*):
|
|
||||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
|
||||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
|
||||||
tensor will ge generated by sampling using the supplied random `generator`.
|
|
||||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
|
||||||
The output format of the generate image. Choose between
|
|
||||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
|
||||||
return_dict (`bool`, *optional*, defaults to `True`):
|
|
||||||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
|
|
||||||
a plain tuple.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
|
|
||||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
|
|
||||||
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
|
|
||||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
|
||||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
|
||||||
"""
|
|
||||||
if height % 8 != 0 or width % 8 != 0:
|
if height % 8 != 0 or width % 8 != 0:
|
||||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||||
|
|
||||||
|
@ -203,21 +193,106 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||||
|
|
||||||
# scale and decode the image latents with vae
|
# scale and decode the image latents with vae
|
||||||
latents = 1 / 0.18215 * latents
|
latents = 1 / 0.18215 * latents
|
||||||
# TODO: check when flax vae gets merged into main
|
|
||||||
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
|
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
|
||||||
|
|
||||||
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
|
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
|
||||||
|
return image
|
||||||
|
|
||||||
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
|
def __call__(
|
||||||
# run safety checker
|
self,
|
||||||
# TODO: check when flax safety checker gets merged into main
|
prompt_ids: jnp.array,
|
||||||
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
|
params: Union[Dict, FrozenDict],
|
||||||
# image, has_nsfw_concept = self.safety_checker(
|
prng_seed: jax.random.PRNGKey,
|
||||||
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
|
num_inference_steps: int = 50,
|
||||||
# )
|
height: int = 512,
|
||||||
has_nsfw_concept = False
|
width: int = 512,
|
||||||
|
guidance_scale: float = 7.5,
|
||||||
|
latents: jnp.array = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
jit: bool = False,
|
||||||
|
debug: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Function invoked when calling the pipeline for generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt (`str` or `List[str]`):
|
||||||
|
The prompt or prompts to guide the image generation.
|
||||||
|
height (`int`, *optional*, defaults to 512):
|
||||||
|
The height in pixels of the generated image.
|
||||||
|
width (`int`, *optional*, defaults to 512):
|
||||||
|
The width in pixels of the generated image.
|
||||||
|
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||||
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
|
expense of slower inference.
|
||||||
|
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||||
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||||
|
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||||
|
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||||
|
usually at the expense of lower image quality.
|
||||||
|
generator (`torch.Generator`, *optional*):
|
||||||
|
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||||
|
deterministic.
|
||||||
|
latents (`jnp.array`, *optional*):
|
||||||
|
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||||
|
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||||
|
tensor will ge generated by sampling using the supplied random `generator`.
|
||||||
|
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||||
|
The output format of the generate image. Choose between
|
||||||
|
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||||
|
jit (`bool`, defaults to `False`):
|
||||||
|
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||||
|
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
|
||||||
|
a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
|
||||||
|
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
|
||||||
|
`tuple. When returning a tuple, the first element is a list with the generated images, and the second
|
||||||
|
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||||
|
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||||
|
"""
|
||||||
|
if jit:
|
||||||
|
images = _p_generate(
|
||||||
|
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
images = self._generate(
|
||||||
|
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||||
|
)
|
||||||
|
|
||||||
|
safety_params = params["safety_checker"]
|
||||||
|
images = (images * 255).round().astype("uint8")
|
||||||
|
images = np.asarray(images).reshape(-1, height, width, 3)
|
||||||
|
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return (image, has_nsfw_concept)
|
return (images, has_nsfw_concept)
|
||||||
|
|
||||||
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: maybe use a config dict instead of so many static argnums
|
||||||
|
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
|
||||||
|
def _p_generate(
|
||||||
|
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||||
|
):
|
||||||
|
return pipe._generate(
|
||||||
|
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jax.pmap, static_broadcasted_argnums=(0,))
|
||||||
|
def _p_get_safety_scores(pipe, features, params):
|
||||||
|
return pipe._get_safety_scores(features, params)
|
||||||
|
|
||||||
|
|
||||||
|
def unshard(x: jnp.ndarray):
|
||||||
|
# einops.rearrange(x, 'd b ... -> (d b) ...')
|
||||||
|
num_devices, batch_size = x.shape[:2]
|
||||||
|
rest = x.shape[2:]
|
||||||
|
return x.reshape(num_devices * batch_size, *rest)
|
||||||
|
|
Loading…
Reference in New Issue