Flax safety checker (#825)
* Remove set_format in Flax pipeline.
* Remove DummyChecker.
* Run safety_checker in pipeline.
* Don't pmap on every call. We could have decorated `generate` with `pmap`, but I wanted to keep it in case someone wants to invoke it in non-parallel mode.
* Remove commented line. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Replicate outside __call__, prepare for optional jitting.
* Remove unnecessary clipping. As suggested by @kashif.
* Do not jit unless requested.
* Send all args to generate.
* make style
* Remove unused imports.
* Fix docstring.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
e713346ad1
commit
78db11dbf3
|
@ -62,11 +62,6 @@ for library in LOADABLE_CLASSES:
|
|||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
|
||||
|
||||
class DummyChecker:
|
||||
def __init__(self):
|
||||
self.dummy = True
|
||||
|
||||
|
||||
def import_flax_or_no_model(module, class_name):
|
||||
try:
|
||||
# 1. First make sure that if a Flax object is present, import this one
|
||||
|
@ -177,10 +172,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
if save_method_name is not None:
|
||||
break
|
||||
|
||||
# TODO(Patrick, Suraj): to delete after
|
||||
if isinstance(sub_model, DummyChecker):
|
||||
continue
|
||||
|
||||
save_method = getattr(sub_model, save_method_name)
|
||||
expects_params = "params" in set(inspect.signature(save_method).parameters.keys())
|
||||
|
||||
|
@ -194,7 +185,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
|
||||
Instantiate a Flax diffusion pipeline from pre-trained pipeline weights.
|
||||
|
||||
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
|
||||
|
||||
|
@ -349,11 +340,6 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
# TODO(Patrick, Suraj) - delete later
|
||||
if class_name == "DummyChecker":
|
||||
library_name = "stable_diffusion"
|
||||
class_name = "FlaxStableDiffusionSafetyChecker"
|
||||
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
loaded_sub_model = None
|
||||
|
||||
|
@ -422,11 +408,7 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
|||
loaded_sub_model, loaded_params = load_method(loadable_folder, from_pt=from_pt, dtype=dtype)
|
||||
params[name] = loaded_params
|
||||
elif is_transformers_available() and issubclass(class_obj, FlaxPreTrainedModel):
|
||||
# make sure we don't initialize the weights to save time
|
||||
if name == "safety_checker":
|
||||
loaded_sub_model = DummyChecker()
|
||||
loaded_params = {}
|
||||
elif from_pt:
|
||||
if from_pt:
|
||||
# TODO(Suraj): Fix this in Transformers. We should be able to use `_do_init=False` here
|
||||
loaded_sub_model = load_method(loadable_folder, from_pt=from_pt)
|
||||
loaded_params = loaded_sub_model.params
|
||||
|
|
|
@ -1,8 +1,14 @@
|
|||
from functools import partial
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training.common_utils import shard
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel
|
||||
|
||||
from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
|
||||
|
@ -77,60 +83,44 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
)
|
||||
return text_input.input_ids
|
||||
|
||||
def __call__(
|
||||
def _get_safety_scores(self, features, params):
|
||||
special_cos_dist, cos_dist = self.safety_checker(features, params)
|
||||
return (special_cos_dist, cos_dist)
|
||||
|
||||
def _run_safety_checker(self, images, safety_model_params, jit=False):
    """Score `images` with the safety checker and filter them accordingly.

    Args:
        images: Iterable of image arrays; each one is converted with
            `Image.fromarray` before feature extraction.
        safety_model_params: Safety checker parameters. They should already
            be replicated when `jit` is True.
        jit: When True, compute the scores with the pmapped
            `_p_get_safety_scores` on sharded features; otherwise call
            `self._get_safety_scores` directly.

    Returns:
        The `(images, has_nsfw)` pair returned by
        `self.safety_checker.filtered_with_scores`.
    """
    pil_images = list(map(Image.fromarray, images))
    extracted = self.feature_extractor(pil_images, return_tensors="np")
    features = extracted.pixel_values

    if not jit:
        special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
    else:
        # Split the batch across devices, score in parallel, then merge the
        # device and batch axes back together.
        features = shard(features)
        sharded_special, sharded_cos = _p_get_safety_scores(self, features, safety_model_params)
        special_cos_dist = unshard(sharded_special)
        cos_dist = unshard(sharded_cos)
        # The filtering step below is not pmapped, so pass it unreplicated params.
        safety_model_params = unreplicate(safety_model_params)

    filtered_images, has_nsfw = self.safety_checker.filtered_with_scores(
        special_cos_dist,
        cos_dist,
        images,
        safety_model_params,
    )
    return filtered_images, has_nsfw
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.PRNGKey,
|
||||
num_inference_steps: Optional[int] = 50,
|
||||
height: Optional[int] = 512,
|
||||
width: Optional[int] = 512,
|
||||
guidance_scale: Optional[float] = 7.5,
|
||||
num_inference_steps: int = 50,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
guidance_scale: float = 7.5,
|
||||
latents: Optional[jnp.array] = None,
|
||||
return_dict: bool = True,
|
||||
debug: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`jnp.array`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
|
||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if height % 8 != 0 or width % 8 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
||||
|
||||
|
@ -203,21 +193,106 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
|
||||
# scale and decode the image latents with vae
|
||||
latents = 1 / 0.18215 * latents
|
||||
# TODO: check when flax vae gets merged into main
|
||||
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
|
||||
|
||||
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
|
||||
return image
|
||||
|
||||
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
|
||||
# run safety checker
|
||||
# TODO: check when flax safety checker gets merged into main
|
||||
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
|
||||
# image, has_nsfw_concept = self.safety_checker(
|
||||
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
|
||||
# )
|
||||
has_nsfw_concept = False
|
||||
def __call__(
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.PRNGKey,
|
||||
num_inference_steps: int = 50,
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
guidance_scale: float = 7.5,
|
||||
latents: jnp.array = None,
|
||||
return_dict: bool = True,
|
||||
jit: bool = False,
|
||||
debug: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
height (`int`, *optional*, defaults to 512):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to 512):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`jnp.array`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
jit (`bool`, defaults to `False`):
|
||||
Whether to run `pmap` versions of the generation and safety scoring functions. NOTE: This argument
|
||||
exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a future release.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] instead of
|
||||
a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.FlaxStableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
|
||||
`tuple`. When returning a tuple, the first element is a list with the generated images, and the second
|
||||
element is a list of `bool`s denoting whether the corresponding generated image likely represents
|
||||
"not-safe-for-work" (nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
if jit:
|
||||
images = _p_generate(
|
||||
self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||
)
|
||||
else:
|
||||
images = self._generate(
|
||||
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
|
||||
)
|
||||
|
||||
safety_params = params["safety_checker"]
|
||||
images = (images * 255).round().astype("uint8")
|
||||
images = np.asarray(images).reshape(-1, height, width, 3)
|
||||
images, has_nsfw_concept = self._run_safety_checker(images, safety_params, jit)
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
return (images, has_nsfw_concept)
|
||||
|
||||
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
|
||||
# TODO: maybe use a config dict instead of so many static argnums
|
||||
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
    pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
):
    """pmapped entry point that forwards every argument to `pipe._generate`.

    `pipe`, `num_inference_steps`, `height`, `width`, `guidance_scale` and
    `debug` (positions 0, 4, 5, 6, 7 and 9) are static broadcasted arguments;
    `prompt_ids`, `params`, `prng_seed` and `latents` are mapped.
    """
    generated = pipe._generate(
        prompt_ids,
        params,
        prng_seed,
        num_inference_steps,
        height,
        width,
        guidance_scale,
        latents,
        debug,
    )
    return generated
|
||||
|
||||
|
||||
@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_safety_scores(pipe, features, params):
    """pmapped wrapper around `pipe._get_safety_scores`.

    `pipe` (position 0) is a static broadcasted argument; `features` and
    `params` are mapped.
    """
    scores = pipe._get_safety_scores(features, params)
    return scores
|
||||
|
||||
|
||||
def unshard(x: jnp.ndarray):
    """Merge the two leading axes of `x` into a single axis.

    Equivalent to `einops.rearrange(x, 'd b ... -> (d b) ...')`: an array of
    shape `(d, b, *rest)` is reshaped to `(d * b, *rest)`.

    Args:
        x: Array with at least two dimensions.

    Returns:
        `x` reshaped so that its first two axes are flattened into one.

    Raises:
        ValueError: If `x` has fewer than two dimensions.
    """
    shape = tuple(x.shape)
    if len(shape) < 2:
        # Fail with a clear message instead of an opaque tuple-unpacking error.
        raise ValueError(f"`unshard` expects an array with at least 2 dimensions, but got shape {shape}.")
    num_devices, batch_size, *rest = shape
    return x.reshape(num_devices * batch_size, *rest)
|
||||
|
|
Loading…
Reference in New Issue